diff --git a/test/test_jit.py b/test/test_jit.py index 2a11a7d480..05bb38d0d9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -30,28 +30,6 @@ def _test_script_module(self, tensor, f, *args): self.assertTrue(torch.allclose(jit_out, py_out)) - def test_torchscript_pad_trim(self): - @torch.jit.script - def jit_method(tensor, max_len, fill_value): - # type: (Tensor, int, float) -> Tensor - return F.pad_trim(tensor, max_len, fill_value) - - tensor = torch.rand((1, 10)) - max_len = 5 - fill_value = 3. - - jit_out = jit_method(tensor, max_len, fill_value) - py_out = F.pad_trim(tensor, max_len, fill_value) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_pad_trim(self): - tensor = torch.rand((1, 10), device="cuda") - max_len = 5 - - self._test_script_module(tensor, transforms.PadTrim, max_len) - def test_torchscript_spectrogram(self): @torch.jit.script def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): diff --git a/test/test_transforms.py b/test/test_transforms.py index 5bdd4e94be..57095f1fc3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -36,20 +36,6 @@ def scale(self, waveform, factor=float(2**31)): waveform = waveform.to(torch.get_default_dtype()) return waveform / factor - def test_pad_trim(self): - - waveform = self.waveform.clone() - length_orig = waveform.size(1) - length_new = int(length_orig * 1.2) - - result = transforms.PadTrim(max_len=length_new)(waveform) - self.assertEqual(result.size(1), length_new) - - length_new = int(length_orig * 0.8) - - result = transforms.PadTrim(max_len=length_new)(waveform) - self.assertEqual(result.size(1), length_new) - def test_mu_law_companding(self): quantization_channels = 256 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 92e712bc04..f24b2d818b 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -3,7 +3,6 @@ __all__ = [ - 'pad_trim', 'istft', 'spectrogram', 'create_fb_matrix', @@ -18,28 +17,6 @@ ] -@torch.jit.script -def pad_trim(waveform, max_len, fill_value): - # type: (Tensor, int, float) -> Tensor - r"""Pad/trim a 2D tensor - - Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) - max_len (int): Length to which the waveform will be padded - fill_value (float): Value to fill in - - Returns: - torch.Tensor: Padded/trimmed tensor - """ - n = waveform.size(1) - if max_len > n: - # TODO add "with torch.no_grad():" back when JIT supports it - waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value) - else: - waveform = waveform[:, :max_len] - return waveform - - # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved @torch.jit.ignore def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 9ec311e4da..d5c85a9f1b 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -7,32 +7,6 @@ from .compliance import kaldi -class PadTrim(torch.jit.ScriptModule): - r"""Pad/Trim a 2D tensor - - Args: - max_len (int): Length to which the waveform will be padded - fill_value (float): Value to fill in - """ - __constants__ = ['max_len', 'fill_value'] - - def __init__(self, max_len, fill_value=0.): - super(PadTrim, self).__init__() - self.max_len = max_len - self.fill_value = fill_value - - @torch.jit.script_method - def forward(self, waveform): - r""" - Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) - - Returns: - Tensor: Tensor of size (c, `max_len`) - """ - return F.pad_trim(waveform, self.max_len, self.fill_value) - - class Spectrogram(torch.jit.ScriptModule): r"""Create a spectrogram from a audio signal