diff --git a/test/test_functional.py b/test/test_functional.py index 90d9c77321..4e36b4dce2 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -16,6 +16,15 @@ import librosa +def _test_torchscript_functional(py_method, *args, **kwargs): + jit_method = torch.jit.script(py_method) + + jit_out = jit_method(*args, **kwargs) + py_out = py_method(*args, **kwargs) + + assert torch.allclose(jit_out, py_out) + + class TestFunctional(unittest.TestCase): data_sizes = [(2, 20), (3, 15), (4, 10)] number_of_trials = 100 @@ -25,6 +34,21 @@ class TestFunctional(unittest.TestCase): test_filepath = os.path.join(test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3') + def test_torchscript_spectrogram(self): + + tensor = torch.rand((1, 1000)) + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws) + power = 2 + normalize = False + + _test_torchscript_functional( + F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize + ) + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) @@ -49,6 +73,7 @@ def test_compute_deltas_randn(self): specgram = torch.randn(channel, n_mfcc, time) computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length) def test_batch_pitch(self): waveform, sample_rate = torchaudio.load(self.test_filepath) @@ -63,6 +88,7 @@ def test_batch_pitch(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate) def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): # trim sound for case when constructed signal is shorter than original @@ -424,6 +450,74 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5) + def test_torchscript_create_fb_matrix(self): + + n_stft = 100 + f_min = 0.0 + f_max = 20.0 + n_mels = 10 + sample_rate = 16000 + + _test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate) + + def test_torchscript_amplitude_to_DB(self): + + spec = torch.rand((6, 201)) + multiplier = 10.0 + amin = 1e-10 + db_multiplier = 0.0 + top_db = 80.0 + + _test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db) + + def test_torchscript_create_dct(self): + + n_mfcc = 40 + n_mels = 128 + norm = "ortho" + + _test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm) + + def test_torchscript_mu_law_encoding(self): + + tensor = torch.rand((1, 10)) + qc = 256 + + _test_torchscript_functional(F.mu_law_encoding, tensor, qc) + + def test_torchscript_mu_law_decoding(self): + + tensor = torch.rand((1, 10)) + qc = 256 + + _test_torchscript_functional(F.mu_law_decoding, tensor, qc) + + def test_torchscript_complex_norm(self): + + complex_tensor = torch.randn(1, 2, 1025, 400, 2), + power = 2 + + _test_torchscript_functional(F.complex_norm, complex_tensor, power) + + def test_mask_along_axis(self): + + specgram = torch.randn(2, 1025, 400), + mask_param = 100 + mask_value = 30. + axis = 2 + + _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis) + + def test_mask_along_axis_iid(self): + + specgram = torch.randn(2, 1025, 400), + specgrams = torch.randn(4, 2, 1025, 400), + mask_param = 100 + mask_value = 30. + axis = 2 + + _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) + @pytest.mark.parametrize('complex_tensor', [ torch.randn(1, 2, 1025, 400, 2), diff --git a/test/test_functional_filtering.py b/test/test_functional_filtering.py index ab209ed7a7..b90c41278b 100644 --- a/test/test_functional_filtering.py +++ b/test/test_functional_filtering.py @@ -9,6 +9,15 @@ import time +def _test_torchscript_functional(py_method, *args, **kwargs): + jit_method = torch.jit.script(py_method) + + jit_out = jit_method(*args, **kwargs) + py_out = py_method(*args, **kwargs) + + assert torch.allclose(jit_out, py_out) + + class TestFunctionalFiltering(unittest.TestCase): test_dirpath, test_dir = common_utils.create_temp_assets_dir() @@ -79,6 +88,7 @@ def _test_lfilter(self, waveform, device): assert len(output_waveform.size()) == 2 assert output_waveform.size(0) == waveform.size(0) assert output_waveform.size(1) == waveform.size(1) + _test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs) def test_lfilter(self): @@ -116,6 +126,7 @@ def test_lowpass(self): output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) + _test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ) def test_highpass(self): """ @@ -135,6 +146,7 @@ def test_highpass(self): # TBD - this fails at the 1e-4 level, debug why assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) + _test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ) def test_equalizer(self): """ @@ -155,6 +167,7 @@ def test_equalizer(self): output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) + _test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q) def test_perf_biquad_filtering(self): @@ -183,6 +196,9 @@ def test_perf_biquad_filtering(self): _timing_lfilter_run_time = time.time() - _timing_lfilter_filtering assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) + _test_torchscript_functional( + F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]) + ) if __name__ == "__main__": diff --git a/test/test_jit.py b/test/test_jit.py deleted file mode 100644 index 8cb1df344d..0000000000 --- a/test/test_jit.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals -import torch -import torchaudio.functional as F -import torchaudio.transforms as transforms -import unittest - -RUN_CUDA = torch.cuda.is_available() -print('Run test with cuda:', RUN_CUDA) - - -class Test_JIT(unittest.TestCase): - def _get_script_module(self, f, *args): - # takes a transform function `f` and wraps it in a script module - class MyModule(torch.jit.ScriptModule): - def __init__(self): - super(MyModule, self).__init__() - self.module = f(*args) - self.module.eval() - - @torch.jit.script_method - def forward(self, tensor): - return self.module(tensor) - - return MyModule() - - def _test_script_module(self, tensor, f, *args): - # tests a script module that wraps a transform function `f` by feeding - # the tensor into the forward function - jit_out = self._get_script_module(f, *args).cuda()(tensor) - py_out = f(*args).cuda()(tensor) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - def test_torchscript_spectrogram(self): - @torch.jit.script - def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): - # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor - return F.spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize) - - tensor = torch.rand((1, 1000)) - n_fft = 400 - ws = 400 - hop = 200 - pad = 0 - window = torch.hann_window(ws) - power = 2 - normalize = False - - jit_out = jit_method(tensor, pad, window, n_fft, hop, ws, power, normalize) - py_out = F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_Spectrogram(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.Spectrogram) - - def test_torchscript_create_fb_matrix(self): - @torch.jit.script - def jit_method(n_stft, f_min, f_max, n_mels, sample_rate): - # type: (int, float, float, int, int) -> Tensor - return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate) - - n_stft = 100 - f_min = 0. - f_max = 20. - n_mels = 10 - sample_rate = 16000 - - jit_out = jit_method(n_stft, f_min, f_max, n_mels, sample_rate) - py_out = F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MelScale(self): - spec_f = torch.rand((1, 6, 201), device="cuda") - - self._test_script_module(spec_f, transforms.MelScale) - - def test_torchscript_amplitude_to_DB(self): - @torch.jit.script - def jit_method(spec, multiplier, amin, db_multiplier, top_db): - # type: (Tensor, float, float, float, Optional[float]) -> Tensor - return F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) - - spec = torch.rand((6, 201)) - multiplier = 10. - amin = 1e-10 - db_multiplier = 0. - top_db = 80. - - jit_out = jit_method(spec, multiplier, amin, db_multiplier, top_db) - py_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_AmplitudeToDB(self): - spec = torch.rand((6, 201), device="cuda") - - self._test_script_module(spec, transforms.AmplitudeToDB) - - def test_torchscript_create_dct(self): - @torch.jit.script - def jit_method(n_mfcc, n_mels, norm): - # type: (int, int, Optional[str]) -> Tensor - return F.create_dct(n_mfcc, n_mels, norm) - - n_mfcc = 40 - n_mels = 128 - norm = 'ortho' - - jit_out = jit_method(n_mfcc, n_mels, norm) - py_out = F.create_dct(n_mfcc, n_mels, norm) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MFCC(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.MFCC) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MelSpectrogram(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.MelSpectrogram) - - def test_torchscript_mu_law_encoding(self): - @torch.jit.script - def jit_method(tensor, qc): - # type: (Tensor, int) -> Tensor - return F.mu_law_encoding(tensor, qc) - - tensor = torch.rand((1, 10)) - qc = 256 - - jit_out = jit_method(tensor, qc) - py_out = F.mu_law_encoding(tensor, qc) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MuLawEncoding(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.MuLawEncoding) - - def test_torchscript_mu_law_decoding(self): - @torch.jit.script - def jit_method(tensor, qc): - # type: (Tensor, int) -> Tensor - return F.mu_law_decoding(tensor, qc) - - tensor = torch.rand((1, 10)) - qc = 256 - - jit_out = jit_method(tensor, qc) - py_out = F.mu_law_decoding(tensor, qc) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MuLawDecoding(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.MuLawDecoding) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_transforms.py b/test/test_transforms.py index 82594bad01..bc82f0a673 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,6 +4,7 @@ import torch import torchaudio +import torchaudio.augmentations as A import torchaudio.transforms as transforms import torchaudio.functional as F from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY @@ -16,6 +17,32 @@ if IMPORT_SCIPY: import scipy +RUN_CUDA = torch.cuda.is_available() +print("Run test with cuda:", RUN_CUDA) + + +def _test_script_module(f, tensor, *args, **kwargs): + + py_method = f(*args, **kwargs) + jit_method = torch.jit.script(py_method) + + py_out = py_method(tensor) + jit_out = jit_method(tensor) + + assert torch.allclose(jit_out, py_out) + + if RUN_CUDA: + + tensor = tensor.to("cuda") + + py_method = py_method.cuda() + jit_method = torch.jit.script(py_method) + + py_out = py_method(tensor) + jit_out = jit_method(tensor) + + assert torch.allclose(jit_out, py_out) + class Tester(unittest.TestCase): @@ -37,6 +64,10 @@ def scale(self, waveform, factor=float(2**31)): waveform = waveform.to(torch.get_default_dtype()) return waveform / factor + def test_scriptmodule_Spectrogram(self): + tensor = torch.rand((1, 1000)) + _test_script_module(transforms.Spectrogram, tensor) + def test_mu_law_companding(self): quantization_channels = 256 @@ -51,6 +82,14 @@ def test_mu_law_companding(self): waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) + def test_scriptmodule_AmplitudeToDB(self): + spec = torch.rand((6, 201)) + _test_script_module(transforms.AmplitudeToDB, spec) + + def test_scriptmodule_MelScale(self): + spec_f = torch.rand((1, 6, 201)) + _test_script_module(transforms.MelScale, spec_f) + def test_melscale_load_save(self): specgram = torch.ones(1, 1000, 100) melscale_transform = transforms.MelScale() @@ -65,6 +104,10 @@ def test_melscale_load_save(self): self.assertEqual(fb_copy.size(), (1000, 128)) self.assertTrue(torch.allclose(fb, fb_copy)) + def test_scriptmodule_MelSpectrogram(self): + tensor = torch.rand((1, 1000)) + _test_script_module(transforms.MelSpectrogram, tensor) + def test_melspectrogram_load_save(self): waveform = self.waveform.float() mel_spectrogram_transform = transforms.MelSpectrogram() @@ -123,6 +166,10 @@ def test_mel2(self): self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) + def test_scriptmodule_MFCC(self): + tensor = torch.rand((1, 1000)) + _test_script_module(transforms.MFCC, tensor) + def test_mfcc(self): audio_orig = self.waveform.clone() audio_scaled = self.scale(audio_orig) # (1, 16000) @@ -326,6 +373,14 @@ def test_batch_compute_deltas(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + def test_scriptmodule_MuLawEncoding(self): + tensor = torch.rand((1, 10)) + _test_script_module(transforms.MuLawEncoding, tensor) + + def test_scriptmodule_MuLawDecoding(self): + tensor = torch.rand((1, 10)) + _test_script_module(transforms.MuLawDecoding, tensor) + def test_batch_mulaw(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 @@ -364,6 +419,21 @@ def test_batch_spectrogram(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + def test_scriptmodule_TimeStretch(self): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.rand((10, 2, n_freq, 10, 2)) + _test_script_module(A.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate) + + def test_scriptmodule_FrequencyMasking(self): + tensor = torch.rand((10, 2, 50, 10, 2)) + _test_script_module(A.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False) + + def test_scriptmodule_TimeMasking(self): + tensor = torch.rand((10, 2, 50, 10, 2)) + _test_script_module(A.TimeMasking, tensor, time_mask_param=30, iid_masks=False) + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index efa07e9fd8..ed1573bdac 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -24,14 +24,13 @@ class TimeStretch(torch.jit.ScriptModule): def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): super(TimeStretch, self).__init__() + self.fixed_rate = fixed_rate + n_fft = (n_freq - 1) * 2 hop_length = hop_length if hop_length is not None else n_fft // 2 - self.fixed_rate = fixed_rate phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] - self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) - @torch.jit.script_method def forward(self, complex_specgrams, overriding_rate=None): # type: (Tensor, Optional[float]) -> Tensor r""" @@ -63,7 +62,7 @@ def forward(self, complex_specgrams, overriding_rate=None): return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) -class _AxisMasking(torch.jit.ScriptModule): +class _AxisMasking(torch.nn.Module): r""" Apply masking to a spectrogram. Args: @@ -80,7 +79,6 @@ def __init__(self, mask_param, axis, iid_masks): self.axis = axis self.iid_masks = iid_masks - @torch.jit.script_method def forward(self, specgram, mask_value=0.): # type: (Tensor, float) -> Tensor r""" diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9bf33f4feb..b06910f7c9 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -221,7 +221,6 @@ def istft( return y -@torch.jit.script def spectrogram( waveform, pad, window, n_fft, hop_length, win_length, power, normalized ): @@ -274,7 +273,6 @@ def spectrogram( return spec_f -@torch.jit.script def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): # type: (Tensor, float, float, float, Optional[float]) -> Tensor r""" @@ -309,7 +307,6 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): return x_db -@torch.jit.script def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): # type: (int, float, float, int, int) -> Tensor r""" @@ -355,7 +352,6 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): return fb -@torch.jit.script def create_dct(n_mfcc, n_mels, norm): # type: (int, int, Optional[str]) -> Tensor r""" @@ -386,7 +382,6 @@ def create_dct(n_mfcc, n_mels, norm): return dct.t() -@torch.jit.script def mu_law_encoding(x, quantization_channels): # type: (Tensor, int) -> Tensor r""" @@ -414,7 +409,6 @@ def mu_law_encoding(x, quantization_channels): return x_mu -@torch.jit.script def mu_law_decoding(x_mu, quantization_channels): # type: (Tensor, int) -> Tensor r""" @@ -442,7 +436,6 @@ def mu_law_decoding(x_mu, quantization_channels): return x -@torch.jit.script def complex_norm(complex_tensor, power=1.0): # type: (Tensor, float) -> Tensor r"""Compute the norm of complex tensor input. @@ -459,7 +452,6 @@ def complex_norm(complex_tensor, power=1.0): return torch.norm(complex_tensor, 2, -1).pow(power) -@torch.jit.script def angle(complex_tensor): # type: (Tensor) -> Tensor r"""Compute the angle of complex tensor input. @@ -473,7 +465,6 @@ def angle(complex_tensor): return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) -@torch.jit.script def magphase(complex_tensor, power=1.0): # type: (Tensor, float) -> Tuple[Tensor, Tensor] r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. @@ -490,7 +481,6 @@ def magphase(complex_tensor, power=1.0): return mag, phase -@torch.jit.script def phase_vocoder(complex_specgrams, rate, phase_advance): # type: (Tensor, float, Tensor) -> Tensor r"""Given a STFT tensor, speed up in time without modifying pitch by a @@ -555,7 +545,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): return complex_specgrams_stretch -@torch.jit.script def lfilter(waveform, a_coeffs, b_coeffs): # type: (Tensor, Tensor, Tensor) -> Tensor r""" @@ -630,7 +619,6 @@ def lfilter(waveform, a_coeffs, b_coeffs): return output -@torch.jit.script def biquad(waveform, b0, b1, b2, a0, a1, a2): # type: (Tensor, float, float, float, float, float, float) -> Tensor r"""Performs a biquad filter of input tensor. Initial conditions set to 0. @@ -665,7 +653,6 @@ def _dB2Linear(x): return math.exp(x * math.log(10) / 20.0) -@torch.jit.script def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): # type: (Tensor, int, float, float) -> Tensor r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation. @@ -695,7 +682,6 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -@torch.jit.script def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): # type: (Tensor, int, float, float) -> Tensor r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation. @@ -725,7 +711,6 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -@torch.jit.script def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): # type: (Tensor, int, float, float, float) -> Tensor r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation. @@ -753,7 +738,6 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -@torch.jit.script def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor r""" @@ -790,7 +774,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): return specgrams -@torch.jit.script def mask_along_axis(specgram, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor r""" @@ -825,7 +808,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): return specgram -@torch.jit.script def compute_deltas(specgram, win_length=5, mode="replicate"): # type: (Tensor, int, str) -> Tensor r"""Compute delta coefficients of a tensor, usually a spectrogram: @@ -878,7 +860,6 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): return output -@torch.jit.script def _compute_nccf(waveform, sample_rate, frame_time, freq_low): # type: (Tensor, int, float, int) -> Tensor r""" @@ -993,7 +974,6 @@ def _median_smoothing(indices, win_length): return values -@torch.jit.script def detect_pitch_frequency( waveform, sample_rate, @@ -1021,7 +1001,7 @@ def detect_pitch_frequency( dim = waveform.dim() # pack batch - shape = waveform.size() + shape = list(waveform.size()) waveform = waveform.reshape([-1] + shape[-1:]) nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) @@ -1033,6 +1013,6 @@ def detect_pitch_frequency( freq = sample_rate / (EPSILON + indices.to(torch.float)) # unpack batch - freq = freq.reshape(shape[:-1] + freq.shape[-1:]) + freq = freq.reshape(shape[:-1] + list(freq.shape[-1:])) return freq diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d65c7275a3..c7e7bfa0ba 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -20,7 +20,7 @@ ] -class Spectrogram(torch.jit.ScriptModule): +class Spectrogram(torch.nn.Module): r"""Create a spectrogram from a audio signal Args: @@ -48,12 +48,11 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None, self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.window = torch.jit.Attribute(window, torch.Tensor) + self.window = window self.pad = pad self.power = power self.normalized = normalized - @torch.jit.script_method def forward(self, waveform): r""" Args: @@ -85,7 +84,7 @@ class AmplitudeToDB(torch.jit.ScriptModule): def __init__(self, stype='power', top_db=None): super(AmplitudeToDB, self).__init__() - self.stype = torch.jit.Attribute(stype, str) + self.stype = stype if top_db is not None and top_db < 0: raise ValueError('top_db must be positive value') self.top_db = torch.jit.Attribute(top_db, Optional[float]) @@ -94,7 +93,6 @@ def __init__(self, stype='power', top_db=None): self.ref_value = 1.0 self.db_multiplier = math.log10(max(self.amin, self.ref_value)) - @torch.jit.script_method def forward(self, x): r"""Numerically stable implementation from Librosa https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html @@ -108,7 +106,7 @@ def forward(self, x): return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db) -class MelScale(torch.jit.ScriptModule): +class MelScale(torch.nn.Module): r"""This turns a normal STFT into a mel frequency STFT, using a conversion matrix. This uses triangular filter banks. @@ -129,13 +127,14 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N self.n_mels = n_mels self.sample_rate = sample_rate self.f_max = f_max if f_max is not None else float(sample_rate // 2) - assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) self.f_min = f_min + + assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) - self.fb = torch.jit.Attribute(fb, torch.Tensor) + self.fb = fb - @torch.jit.script_method def forward(self, specgram): r""" Args: @@ -156,7 +155,7 @@ def forward(self, specgram): return mel_specgram -class MelSpectrogram(torch.jit.ScriptModule): +class MelSpectrogram(torch.nn.Module): r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram and MelScale. @@ -194,7 +193,7 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.pad = pad self.n_mels = n_mels # number of mel frequency bins - self.f_max = torch.jit.Attribute(f_max, Optional[float]) + self.f_max = f_max self.f_min = f_min self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length, @@ -202,7 +201,6 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non normalized=False, wkwargs=wkwargs) self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1) - @torch.jit.script_method def forward(self, waveform): r""" Args: @@ -216,7 +214,7 @@ def forward(self, waveform): return mel_specgram -class MFCC(torch.jit.ScriptModule): +class MFCC(torch.nn.Module): r"""Create the Mel-frequency cepstrum coefficients from an audio signal By default, this calculates the MFCC on the DB-scaled Mel spectrogram. @@ -247,7 +245,7 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m self.sample_rate = sample_rate self.n_mfcc = n_mfcc self.dct_type = dct_type - self.norm = torch.jit.Attribute(norm, Optional[str]) + self.norm = norm self.top_db = 80.0 self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) @@ -259,10 +257,9 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m if self.n_mfcc > self.MelSpectrogram.n_mels: raise ValueError('Cannot select more MFCC coefficients than # mel bins') dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) - self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor) + self.dct_mat = dct_mat self.log_mels = log_mels - @torch.jit.script_method def forward(self, waveform): r""" Args: @@ -283,7 +280,7 @@ def forward(self, waveform): return mfcc -class MuLawEncoding(torch.jit.ScriptModule): +class MuLawEncoding(torch.nn.Module): r"""Encode signal based on mu-law companding. For more info see the `Wikipedia Entry `_ @@ -299,7 +296,6 @@ def __init__(self, quantization_channels=256): super(MuLawEncoding, self).__init__() self.quantization_channels = quantization_channels - @torch.jit.script_method def forward(self, x): r""" Args: @@ -311,7 +307,7 @@ def forward(self, x): return F.mu_law_encoding(x, self.quantization_channels) -class MuLawDecoding(torch.jit.ScriptModule): +class MuLawDecoding(torch.nn.Module): r"""Decode mu-law encoded signal. For more info see the `Wikipedia Entry `_ @@ -327,7 +323,6 @@ def __init__(self, quantization_channels=256): super(MuLawDecoding, self).__init__() self.quantization_channels = quantization_channels - @torch.jit.script_method def forward(self, x_mu): r""" Args: @@ -368,7 +363,7 @@ def forward(self, waveform): raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) -class ComplexNorm(torch.jit.ScriptModule): +class ComplexNorm(torch.nn.Module): r"""Compute the norm of complex tensor input Args: power (float): Power of the norm. Defaults to `1.0`. @@ -379,7 +374,6 @@ def __init__(self, power=1.0): super(ComplexNorm, self).__init__() self.power = power - @torch.jit.script_method def forward(self, complex_tensor): r""" Args: @@ -390,7 +384,7 @@ def forward(self, complex_tensor): return F.complex_norm(complex_tensor, self.power) -class ComputeDeltas(torch.jit.ScriptModule): +class ComputeDeltas(torch.nn.Module): r"""Compute delta coefficients of a tensor, usually a spectrogram. See `torchaudio.functional.compute_deltas` for more details. @@ -403,9 +397,8 @@ class ComputeDeltas(torch.jit.ScriptModule): def __init__(self, win_length=5, mode="replicate"): super(ComputeDeltas, self).__init__() self.win_length = win_length - self.mode = torch.jit.Attribute(mode, str) + self.mode = mode - @torch.jit.script_method def forward(self, specgram): r""" Args: