diff --git a/test/test_transforms.py b/test/test_transforms.py
index 8cca3085e0..3fe615437a 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -326,6 +326,26 @@ def test_scriptmodule_Resample(self):
 
         _test_script_module(transforms.Spectrogram, tensor, sample_rate, sample_rate_2)
 
+    def test_batch_Resample(self):
+        """Resampling a batch must equal batching the resampled waveform."""
+        waveform = torch.randn(2, 2786)
+
+        # Single then transform then batch
+        expected = transforms.Resample()(waveform).repeat(3, 1, 1)
+
+        # Batch then transform
+        computed = transforms.Resample()(waveform.repeat(3, 1, 1))
+
+        self.assertEqual(computed.shape, expected.shape)
+        self.assertTrue(torch.allclose(computed, expected))
+
+        # Non-contiguous batch: `view` cannot flatten these leading dims, `reshape` can
+        strided = waveform.repeat(3, 1, 1).transpose(0, 1)
+        computed = transforms.Resample()(strided)
+
+        self.assertEqual(computed.shape, expected.transpose(0, 1).shape)
+        self.assertTrue(torch.allclose(computed, expected.transpose(0, 1)))
+
     def test_scriptmodule_ComplexNorm(self):
         tensor = torch.rand((1, 2, 201, 2))
         _test_script_module(transforms.ComplexNorm, tensor)
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 10390e6651..707252a9f3 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -445,7 +445,18 @@ def forward(self, waveform):
             torch.Tensor: Output signal of dimension (..., time)
         """
         if self.resampling_method == 'sinc_interpolation':
-            return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
+            # pack batch: fold every leading dimension into one so the resampler
+            # receives a 2D (N, time) tensor. `reshape` (not `view`) also accepts
+            # non-contiguous input, copying only when a view is impossible.
+            shape = waveform.size()
+            waveform = waveform.reshape(-1, shape[-1])
+
+            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
+
+            # unpack batch: restore the leading dimensions around the new time length
+            waveform = waveform.reshape(shape[:-1] + waveform.shape[-1:])
+
+            return waveform
 
         raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
 