Skip to content

Commit 289f08a

Browse files
jamarshoncpuhrsch
authored andcommitted
more (#160)
1 parent 6a43e9e commit 289f08a

File tree

4 files changed

+0
-85
lines changed

4 files changed

+0
-85
lines changed

test/test_jit.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,6 @@ def _test_script_module(self, tensor, f, *args):
3030

3131
self.assertTrue(torch.allclose(jit_out, py_out))
3232

33-
def test_torchscript_pad_trim(self):
34-
@torch.jit.script
35-
def jit_method(tensor, max_len, fill_value):
36-
# type: (Tensor, int, float) -> Tensor
37-
return F.pad_trim(tensor, max_len, fill_value)
38-
39-
tensor = torch.rand((1, 10))
40-
max_len = 5
41-
fill_value = 3.
42-
43-
jit_out = jit_method(tensor, max_len, fill_value)
44-
py_out = F.pad_trim(tensor, max_len, fill_value)
45-
46-
self.assertTrue(torch.allclose(jit_out, py_out))
47-
48-
@unittest.skipIf(not RUN_CUDA, "no CUDA")
49-
def test_scriptmodule_pad_trim(self):
50-
tensor = torch.rand((1, 10), device="cuda")
51-
max_len = 5
52-
53-
self._test_script_module(tensor, transforms.PadTrim, max_len)
54-
5533
def test_torchscript_spectrogram(self):
5634
@torch.jit.script
5735
def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):

test/test_transforms.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,6 @@ def scale(self, waveform, factor=float(2**31)):
3636
waveform = waveform.to(torch.get_default_dtype())
3737
return waveform / factor
3838

39-
def test_pad_trim(self):
40-
41-
waveform = self.waveform.clone()
42-
length_orig = waveform.size(1)
43-
length_new = int(length_orig * 1.2)
44-
45-
result = transforms.PadTrim(max_len=length_new)(waveform)
46-
self.assertEqual(result.size(1), length_new)
47-
48-
length_new = int(length_orig * 0.8)
49-
50-
result = transforms.PadTrim(max_len=length_new)(waveform)
51-
self.assertEqual(result.size(1), length_new)
52-
5339
def test_mu_law_companding(self):
5440

5541
quantization_channels = 256

torchaudio/functional.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44

55
__all__ = [
6-
'pad_trim',
76
'istft',
87
'spectrogram',
98
'create_fb_matrix',
@@ -18,28 +17,6 @@
1817
]
1918

2019

21-
@torch.jit.script
22-
def pad_trim(waveform, max_len, fill_value):
23-
# type: (Tensor, int, float) -> Tensor
24-
r"""Pad/trim a 2D tensor
25-
26-
Args:
27-
waveform (torch.Tensor): Tensor of audio of size (c, n)
28-
max_len (int): Length to which the waveform will be padded
29-
fill_value (float): Value to fill in
30-
31-
Returns:
32-
torch.Tensor: Padded/trimmed tensor
33-
"""
34-
n = waveform.size(1)
35-
if max_len > n:
36-
# TODO add "with torch.no_grad():" back when JIT supports it
37-
waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value)
38-
else:
39-
waveform = waveform[:, :max_len]
40-
return waveform
41-
42-
4320
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
4421
@torch.jit.ignore
4522
def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):

torchaudio/transforms.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,6 @@
77
from .compliance import kaldi
88

99

10-
class PadTrim(torch.jit.ScriptModule):
11-
r"""Pad/Trim a 2D tensor
12-
13-
Args:
14-
max_len (int): Length to which the waveform will be padded
15-
fill_value (float): Value to fill in
16-
"""
17-
__constants__ = ['max_len', 'fill_value']
18-
19-
def __init__(self, max_len, fill_value=0.):
20-
super(PadTrim, self).__init__()
21-
self.max_len = max_len
22-
self.fill_value = fill_value
23-
24-
@torch.jit.script_method
25-
def forward(self, waveform):
26-
r"""
27-
Args:
28-
waveform (torch.Tensor): Tensor of audio of size (c, n)
29-
30-
Returns:
31-
Tensor: Tensor of size (c, `max_len`)
32-
"""
33-
return F.pad_trim(waveform, self.max_len, self.fill_value)
34-
35-
3610
class Spectrogram(torch.jit.ScriptModule):
3711
r"""Create a spectrogram from a audio signal
3812

0 commit comments

Comments
 (0)