From 402cef711756d31a9375e520b3031e58ec5b2945 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 18 Nov 2019 18:19:23 -0500 Subject: [PATCH 01/16] nn.Module. --- torchaudio/functional.py | 12 ------------ torchaudio/transforms.py | 27 +++++++++------------------ 2 files changed, 9 insertions(+), 30 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9bf33f4feb..387d840b75 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -221,7 +221,6 @@ def istft( return y -@torch.jit.script def spectrogram( waveform, pad, window, n_fft, hop_length, win_length, power, normalized ): @@ -274,7 +273,6 @@ def spectrogram( return spec_f -@torch.jit.script def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): # type: (Tensor, float, float, float, Optional[float]) -> Tensor r""" @@ -309,7 +307,6 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): return x_db -@torch.jit.script def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): # type: (int, float, float, int, int) -> Tensor r""" @@ -355,7 +352,6 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): return fb -@torch.jit.script def create_dct(n_mfcc, n_mels, norm): # type: (int, int, Optional[str]) -> Tensor r""" @@ -386,7 +382,6 @@ def create_dct(n_mfcc, n_mels, norm): return dct.t() -@torch.jit.script def mu_law_encoding(x, quantization_channels): # type: (Tensor, int) -> Tensor r""" @@ -414,7 +409,6 @@ def mu_law_encoding(x, quantization_channels): return x_mu -@torch.jit.script def mu_law_decoding(x_mu, quantization_channels): # type: (Tensor, int) -> Tensor r""" @@ -442,7 +436,6 @@ def mu_law_decoding(x_mu, quantization_channels): return x -@torch.jit.script def complex_norm(complex_tensor, power=1.0): # type: (Tensor, float) -> Tensor r"""Compute the norm of complex tensor input. @@ -490,7 +483,6 @@ def magphase(complex_tensor, power=1.0): return mag, phase -@torch.jit.script def phase_vocoder(complex_specgrams, rate, phase_advance): # type: (Tensor, float, Tensor) -> Tensor r"""Given a STFT tensor, speed up in time without modifying pitch by a @@ -725,7 +717,6 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -@torch.jit.script def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): # type: (Tensor, int, float, float, float) -> Tensor r"""Designs biquad peaking equalizer filter and performs filtering. Similar to SoX implementation. @@ -753,7 +744,6 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -@torch.jit.script def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor r""" @@ -790,7 +780,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): return specgrams -@torch.jit.script def mask_along_axis(specgram, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor r""" @@ -825,7 +814,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): return specgram -@torch.jit.script def compute_deltas(specgram, win_length=5, mode="replicate"): # type: (Tensor, int, str) -> Tensor r"""Compute delta coefficients of a tensor, usually a spectrogram: diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d65c7275a3..1b361ce526 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -20,7 +20,7 @@ ] -class Spectrogram(torch.jit.ScriptModule): +class Spectrogram(torch.nn.Module): r"""Create a spectrogram from a audio signal Args: @@ -53,7 +53,6 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None, self.power = power self.normalized = normalized - @torch.jit.script_method def forward(self, waveform): r""" Args: @@ -68,7 +67,7 @@ def forward(self, waveform): self.win_length, self.power, self.normalized) -class AmplitudeToDB(torch.jit.ScriptModule): +class AmplitudeToDB(torch.nn.Module): r"""Turns a tensor from the power/amplitude scale to the decibel scale. This output depends on the maximum value in the input tensor, and so @@ -94,7 +93,6 @@ def __init__(self, stype='power', top_db=None): self.ref_value = 1.0 self.db_multiplier = math.log10(max(self.amin, self.ref_value)) - @torch.jit.script_method def forward(self, x): r"""Numerically stable implementation from Librosa https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html @@ -108,7 +106,7 @@ def forward(self, x): return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db) -class MelScale(torch.jit.ScriptModule): +class MelScale(torch.nn.Module): r"""This turns a normal STFT into a mel frequency STFT, using a conversion matrix. This uses triangular filter banks. @@ -135,7 +133,6 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) self.fb = torch.jit.Attribute(fb, torch.Tensor) - @torch.jit.script_method def forward(self, specgram): r""" Args: @@ -156,7 +153,7 @@ def forward(self, specgram): return mel_specgram -class MelSpectrogram(torch.jit.ScriptModule): +class MelSpectrogram(torch.nn.Module): r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram and MelScale. @@ -202,7 +199,6 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non normalized=False, wkwargs=wkwargs) self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1) - @torch.jit.script_method def forward(self, waveform): r""" Args: @@ -216,7 +212,7 @@ def forward(self, waveform): return mel_specgram -class MFCC(torch.jit.ScriptModule): +class MFCC(torch.nn.Module): r"""Create the Mel-frequency cepstrum coefficients from an audio signal By default, this calculates the MFCC on the DB-scaled Mel spectrogram. @@ -262,7 +258,6 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor) self.log_mels = log_mels - @torch.jit.script_method def forward(self, waveform): r""" Args: @@ -283,7 +278,7 @@ def forward(self, waveform): return mfcc -class MuLawEncoding(torch.jit.ScriptModule): +class MuLawEncoding(torch.nn.Module): r"""Encode signal based on mu-law companding. For more info see the `Wikipedia Entry `_ @@ -299,7 +294,6 @@ def __init__(self, quantization_channels=256): super(MuLawEncoding, self).__init__() self.quantization_channels = quantization_channels - @torch.jit.script_method def forward(self, x): r""" Args: @@ -311,7 +305,7 @@ def forward(self, x): return F.mu_law_encoding(x, self.quantization_channels) -class MuLawDecoding(torch.jit.ScriptModule): +class MuLawDecoding(torch.nn.Module): r"""Decode mu-law encoded signal. For more info see the `Wikipedia Entry `_ @@ -327,7 +321,6 @@ def __init__(self, quantization_channels=256): super(MuLawDecoding, self).__init__() self.quantization_channels = quantization_channels - @torch.jit.script_method def forward(self, x_mu): r""" Args: @@ -368,7 +361,7 @@ def forward(self, waveform): raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) -class ComplexNorm(torch.jit.ScriptModule): +class ComplexNorm(torch.nn.Module): r"""Compute the norm of complex tensor input Args: power (float): Power of the norm. Defaults to `1.0`. @@ -379,7 +372,6 @@ def __init__(self, power=1.0): super(ComplexNorm, self).__init__() self.power = power - @torch.jit.script_method def forward(self, complex_tensor): r""" Args: @@ -390,7 +382,7 @@ def forward(self, complex_tensor): return F.complex_norm(complex_tensor, self.power) -class ComputeDeltas(torch.jit.ScriptModule): +class ComputeDeltas(torch.nn.Module): r"""Compute delta coefficients of a tensor, usually a spectrogram. See `torchaudio.functional.compute_deltas` for more details. @@ -405,7 +397,6 @@ def __init__(self, win_length=5, mode="replicate"): self.win_length = win_length self.mode = torch.jit.Attribute(mode, str) - @torch.jit.script_method def forward(self, specgram): r""" Args: From cf5255b73359a90cc76230d98a1f887975cc214e Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 18 Nov 2019 18:30:01 -0500 Subject: [PATCH 02/16] generalizing spectrogram test. --- test/test_jit.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 8cb1df344d..95739e85c1 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -31,11 +31,15 @@ def _test_script_module(self, tensor, f, *args): self.assertTrue(torch.allclose(jit_out, py_out)) + def _test_torchscript_functional(self, py_method, *args): + jit_out = torch.jit.script(py_method) + + jit_out = jit_method(*args) + py_out = py_method(*args) + + self.assertTrue(torch.allclose(jit_out, py_out)) + def test_torchscript_spectrogram(self): - @torch.jit.script - def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): - # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor - return F.spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize) tensor = torch.rand((1, 1000)) n_fft = 400 @@ -46,10 +50,7 @@ def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): power = 2 normalize = False - jit_out = jit_method(tensor, pad, window, n_fft, hop, ws, power, normalize) - py_out = F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize) - - self.assertTrue(torch.allclose(jit_out, py_out)) + self._test_torscript_functional(tensor, pad, window, n_fft, hop, ws, power, normalize) @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_Spectrogram(self): From 528e1ab91886c9ad6b9cade3cdebef02bed76fa5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 11:50:58 -0500 Subject: [PATCH 03/16] adding test to compile functionals. --- test/test_functional.py | 26 ++++++++++++++++++++++++++ test/test_functional_filtering.py | 16 ++++++++++++++++ test/test_jit.py | 21 --------------------- torchaudio/functional.py | 5 ++--- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 90d9c77321..eb60031851 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -16,6 +16,15 @@ import librosa +def _test_torchscript_functional(py_method, *args, **kwargs): + jit_method = torch.jit.script(py_method) + + jit_out = jit_method(*args, **kwargs) + py_out = py_method(*args, **kwargs) + + assert torch.allclose(jit_out, py_out) + + class TestFunctional(unittest.TestCase): data_sizes = [(2, 20), (3, 15), (4, 10)] number_of_trials = 100 @@ -25,6 +34,21 @@ class TestFunctional(unittest.TestCase): test_filepath = os.path.join(test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3') + def test_torchscript_spectrogram(self): + + tensor = torch.rand((1, 1000)) + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws) + power = 2 + normalize = False + + _test_torchscript_functional( + F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize + ) + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) @@ -49,6 +73,7 @@ def test_compute_deltas_randn(self): specgram = torch.randn(channel, n_mfcc, time) computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length) def test_batch_pitch(self): waveform, sample_rate = torchaudio.load(self.test_filepath) @@ -63,6 +88,7 @@ def test_batch_pitch(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate) def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): # trim sound for case when constructed signal is shorter than original diff --git a/test/test_functional_filtering.py b/test/test_functional_filtering.py index ab209ed7a7..e58894cb63 100644 --- a/test/test_functional_filtering.py +++ b/test/test_functional_filtering.py @@ -9,6 +9,15 @@ import time +def _test_torchscript_functional(py_method, *args, **kwargs): + jit_method = torch.jit.script(py_method) + + jit_out = jit_method(*args, **kwargs) + py_out = py_method(*args, **kwargs) + + assert torch.allclose(jit_out, py_out) + + class TestFunctionalFiltering(unittest.TestCase): test_dirpath, test_dir = common_utils.create_temp_assets_dir() @@ -79,6 +88,7 @@ def _test_lfilter(self, waveform, device): assert len(output_waveform.size()) == 2 assert output_waveform.size(0) == waveform.size(0) assert output_waveform.size(1) == waveform.size(1) + _test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs) def test_lfilter(self): @@ -116,6 +126,7 @@ def test_lowpass(self): output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) + _test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ) def test_highpass(self): """ @@ -135,6 +146,7 @@ def test_highpass(self): # TBD - this fails at the 1e-4 level, debug why assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) + _test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ) def test_equalizer(self): """ @@ -155,6 +167,7 @@ def test_equalizer(self): output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) + _test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q) def test_perf_biquad_filtering(self): @@ -183,6 +196,9 @@ def test_perf_biquad_filtering(self): _timing_lfilter_run_time = time.time() - _timing_lfilter_filtering assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) + _test_torchscript_functional( + F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]) + ) if __name__ == "__main__": diff --git a/test/test_jit.py b/test/test_jit.py index 95739e85c1..4d0911c955 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -31,27 +31,6 @@ def _test_script_module(self, tensor, f, *args): self.assertTrue(torch.allclose(jit_out, py_out)) - def _test_torchscript_functional(self, py_method, *args): - jit_out = torch.jit.script(py_method) - - jit_out = jit_method(*args) - py_out = py_method(*args) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - def test_torchscript_spectrogram(self): - - tensor = torch.rand((1, 1000)) - n_fft = 400 - ws = 400 - hop = 200 - pad = 0 - window = torch.hann_window(ws) - power = 2 - normalize = False - - self._test_torscript_functional(tensor, pad, window, n_fft, hop, ws, power, normalize) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_Spectrogram(self): tensor = torch.rand((1, 1000), device="cuda") diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 387d840b75..e2d3759f42 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -981,7 +981,6 @@ def _median_smoothing(indices, win_length): return values -@torch.jit.script def detect_pitch_frequency( waveform, sample_rate, @@ -1009,7 +1008,7 @@ def detect_pitch_frequency( dim = waveform.dim() # pack batch - shape = waveform.size() + shape = list(waveform.size()) waveform = waveform.reshape([-1] + shape[-1:]) nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) @@ -1021,6 +1020,6 @@ def detect_pitch_frequency( freq = sample_rate / (EPSILON + indices.to(torch.float)) # unpack batch - freq = freq.reshape(shape[:-1] + freq.shape[-1:]) + freq = freq.reshape(shape[:-1] + list(freq.shape[-1:])) return freq From c626790924ad44cbf640d4edf9eb74df1cbb8bc5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 13:19:26 -0500 Subject: [PATCH 04/16] add cuda/cpu compiolation test. --- test/test_functional.py | 68 +++++++++++++++++++++++++++++++++ test/test_transforms.py | 83 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+) diff --git a/test/test_functional.py b/test/test_functional.py index eb60031851..b8de12cd8c 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -450,6 +450,74 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5) + def test_torchscript_create_fb_matrix(self): + + n_stft = 100 + f_min = 0.0 + f_max = 20.0 + n_mels = 10 + sample_rate = 16000 + + _test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate) + + def test_torchscript_amplitude_to_DB(self): + + spec = torch.rand((6, 201)) + multiplier = 10.0 + amin = 1e-10 + db_multiplier = 0.0 + top_db = 80.0 + + _test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db) + + def test_torchscript_create_dct(self): + + n_mfcc = 40 + n_mels = 128 + norm = "ortho" + + _test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm) + + def test_torchscript_mu_law_encoding(self): + + tensor = torch.rand((1, 10)) + qc = 256 + + _test_torchscript_functional(F.mu_law_encoding, tensor, qc) + + def test_torchscript_mu_law_decoding(self): + + tensor = torch.rand((1, 10)) + qc = 256 + + _test_torchscript_functional(F.mu_law_decoding, tensor, qc) + + def test_torchscript_mu_law_decoding(self): + + complex_tensor = torch.randn(1, 2, 1025, 400, 2), + power = 2 + + _test_torchscript_functional(F.complex_norm, complex_tensor, power) + + def test_mask_along_axis(self): + + specgram = torch.randn(2, 1025, 400), + mask_param = 100 + mask_value = 30. + axis = 2 + + _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis) + + def test_mask_along_axis_iid(self): + + specgram = torch.randn(2, 1025, 400), + specgrams = torch.randn(4, 2, 1025, 400), + mask_param = 100 + mask_value = 30. + axis = 2 + + _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) + @pytest.mark.parametrize('complex_tensor', [ torch.randn(1, 2, 1025, 400, 2), diff --git a/test/test_transforms.py b/test/test_transforms.py index 82594bad01..4987ae7579 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -16,6 +16,48 @@ if IMPORT_SCIPY: import scipy +RUN_CUDA = torch.cuda.is_available() +print("Run test with cuda:", RUN_CUDA) + + +def _get_script_module(self, f, *args): + # takes a transform function `f` and wraps it in a script module + class MyModule(torch.jit.ScriptModule): + def __init__(self): + super(MyModule, self).__init__() + self.module = f(*args) + self.module.eval() + + @torch.jit.script_method + def forward(self, tensor): + return self.module(tensor) + + return MyModule() + + +def _test_script_module(py_method, tensor, f, *args): + # tests a script module that wraps a transform function `f` by feeding + # the tensor into the forward function + jit_method = _get_script_module(f, *args) + py_method = f(*args) + + jit_out = jit_method(tensor) + py_out = py_method(tensor) + + self.assertTrue(torch.allclose(jit_out, py_out)) + + if RUN_CUDA: + + tensor = tensor.to("cuda") + + jit_method = _get_script_module(f, *args).cuda() + py_method = f(*args).cuda() + + jit_out = jit_method(tensor) + py_out = py_method(tensor) + + self.assertTrue(torch.allclose(jit_out, py_out)) + class Tester(unittest.TestCase): @@ -37,6 +79,11 @@ def scale(self, waveform, factor=float(2**31)): waveform = waveform.to(torch.get_default_dtype()) return waveform / factor + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_Spectrogram(self): + tensor = torch.rand((1, 1000), device="cuda") + self._test_script_module(tensor, transforms.Spectrogram) + def test_mu_law_companding(self): quantization_channels = 256 @@ -51,6 +98,18 @@ def test_mu_law_companding(self): waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_AmplitudeToDB(self): + spec = torch.rand((6, 201), device="cuda") + + self._test_script_module(spec, transforms.AmplitudeToDB) + + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_MelScale(self): + spec_f = torch.rand((1, 6, 201), device="cuda") + + self._test_script_module(spec_f, transforms.MelScale) + def test_melscale_load_save(self): specgram = torch.ones(1, 1000, 100) melscale_transform = transforms.MelScale() @@ -65,6 +124,12 @@ def test_melscale_load_save(self): self.assertEqual(fb_copy.size(), (1000, 128)) self.assertTrue(torch.allclose(fb, fb_copy)) + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_MelSpectrogram(self): + tensor = torch.rand((1, 1000), device="cuda") + + self._test_script_module(tensor, transforms.MelSpectrogram) + def test_melspectrogram_load_save(self): waveform = self.waveform.float() mel_spectrogram_transform = transforms.MelSpectrogram() @@ -123,6 +188,12 @@ def test_mel2(self): self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_MFCC(self): + tensor = torch.rand((1, 1000), device="cuda") + + self._test_script_module(tensor, transforms.MFCC) + def test_mfcc(self): audio_orig = self.waveform.clone() audio_scaled = self.scale(audio_orig) # (1, 16000) @@ -326,6 +397,18 @@ def test_batch_compute_deltas(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_MuLawEncoding(self): + tensor = torch.rand((1, 10), device="cuda") + + self._test_script_module(tensor, transforms.MuLawEncoding) + + @unittest.skipIf(not RUN_CUDA, "no CUDA") + def test_scriptmodule_MuLawDecoding(self): + tensor = torch.rand((1, 10), device="cuda") + + self._test_script_module(tensor, transforms.MuLawDecoding) + def test_batch_mulaw(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 From 80cd3df1d63d5c3bbe12383f57fcb3b9b70afb2d Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 13:27:02 -0500 Subject: [PATCH 05/16] adding transform test. --- test/test_transforms.py | 50 ++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 4987ae7579..c660a0c502 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -20,7 +20,7 @@ print("Run test with cuda:", RUN_CUDA) -def _get_script_module(self, f, *args): +def _get_script_module(f, *args): # takes a transform function `f` and wraps it in a script module class MyModule(torch.jit.ScriptModule): def __init__(self): @@ -35,16 +35,17 @@ def forward(self, tensor): return MyModule() -def _test_script_module(py_method, tensor, f, *args): +def _test_script_module(f, tensor, *args): # tests a script module that wraps a transform function `f` by feeding # the tensor into the forward function + jit_method = _get_script_module(f, *args) py_method = f(*args) jit_out = jit_method(tensor) py_out = py_method(tensor) - self.assertTrue(torch.allclose(jit_out, py_out)) + assert torch.allclose(jit_out, py_out) if RUN_CUDA: @@ -56,7 +57,7 @@ def _test_script_module(py_method, tensor, f, *args): jit_out = jit_method(tensor) py_out = py_method(tensor) - self.assertTrue(torch.allclose(jit_out, py_out)) + assert torch.allclose(jit_out, py_out) class Tester(unittest.TestCase): @@ -79,10 +80,9 @@ def scale(self, waveform, factor=float(2**31)): waveform = waveform.to(torch.get_default_dtype()) return waveform / factor - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_Spectrogram(self): - tensor = torch.rand((1, 1000), device="cuda") - self._test_script_module(tensor, transforms.Spectrogram) + tensor = torch.rand((1, 1000)) + _test_script_module(transforms.Spectrogram, tensor) def test_mu_law_companding(self): @@ -98,17 +98,13 @@ def test_mu_law_companding(self): waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_AmplitudeToDB(self): - spec = torch.rand((6, 201), device="cuda") - - self._test_script_module(spec, transforms.AmplitudeToDB) + spec = torch.rand((6, 201)) + _test_script_module(transforms.AmplitudeToDB, spec) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MelScale(self): - spec_f = torch.rand((1, 6, 201), device="cuda") - - self._test_script_module(spec_f, transforms.MelScale) + spec_f = torch.rand((1, 6, 201)) + _test_script_module(transforms.MelScale, spec_f) def test_melscale_load_save(self): specgram = torch.ones(1, 1000, 100) @@ -124,11 +120,9 @@ def test_melscale_load_save(self): self.assertEqual(fb_copy.size(), (1000, 128)) self.assertTrue(torch.allclose(fb, fb_copy)) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MelSpectrogram(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.MelSpectrogram) + tensor = torch.rand((1, 1000)) + _test_script_module(transforms.MelSpectrogram, tensor) def test_melspectrogram_load_save(self): waveform = self.waveform.float() @@ -188,11 +182,9 @@ def test_mel2(self): self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MFCC(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.MFCC) + tensor = torch.rand((1, 1000)) + _test_script_module(transforms.MFCC, tensor) def test_mfcc(self): audio_orig = self.waveform.clone() @@ -397,17 +389,13 @@ def test_batch_compute_deltas(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MuLawEncoding(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.MuLawEncoding) + tensor = torch.rand((1, 10)) + _test_script_module(transforms.MuLawEncoding, tensor) - @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MuLawDecoding(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.MuLawDecoding) + tensor = torch.rand((1, 10)) + _test_script_module(transforms.MuLawDecoding, tensor) def test_batch_mulaw(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 From ba13f8f2b2ca7c7a2711cc0215d04d1e7fdbc8e5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 13:27:42 -0500 Subject: [PATCH 06/16] remove standalone jit file. --- test/test_jit.py | 155 ----------------------------------------------- 1 file changed, 155 deletions(-) delete mode 100644 test/test_jit.py diff --git a/test/test_jit.py b/test/test_jit.py deleted file mode 100644 index 4d0911c955..0000000000 --- a/test/test_jit.py +++ /dev/null @@ -1,155 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals -import torch -import torchaudio.functional as F -import torchaudio.transforms as transforms -import unittest - -RUN_CUDA = torch.cuda.is_available() -print('Run test with cuda:', RUN_CUDA) - - -class Test_JIT(unittest.TestCase): - def _get_script_module(self, f, *args): - # takes a transform function `f` and wraps it in a script module - class MyModule(torch.jit.ScriptModule): - def __init__(self): - super(MyModule, self).__init__() - self.module = f(*args) - self.module.eval() - - @torch.jit.script_method - def forward(self, tensor): - return self.module(tensor) - - return MyModule() - - def _test_script_module(self, tensor, f, *args): - # tests a script module that wraps a transform function `f` by feeding - # the tensor into the forward function - jit_out = self._get_script_module(f, *args).cuda()(tensor) - py_out = f(*args).cuda()(tensor) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_Spectrogram(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.Spectrogram) - - def test_torchscript_create_fb_matrix(self): - @torch.jit.script - def jit_method(n_stft, f_min, f_max, n_mels, sample_rate): - # type: (int, float, float, int, int) -> Tensor - return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate) - - n_stft = 100 - f_min = 0. - f_max = 20. - n_mels = 10 - sample_rate = 16000 - - jit_out = jit_method(n_stft, f_min, f_max, n_mels, sample_rate) - py_out = F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MelScale(self): - spec_f = torch.rand((1, 6, 201), device="cuda") - - self._test_script_module(spec_f, transforms.MelScale) - - def test_torchscript_amplitude_to_DB(self): - @torch.jit.script - def jit_method(spec, multiplier, amin, db_multiplier, top_db): - # type: (Tensor, float, float, float, Optional[float]) -> Tensor - return F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) - - spec = torch.rand((6, 201)) - multiplier = 10. - amin = 1e-10 - db_multiplier = 0. - top_db = 80. - - jit_out = jit_method(spec, multiplier, amin, db_multiplier, top_db) - py_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_AmplitudeToDB(self): - spec = torch.rand((6, 201), device="cuda") - - self._test_script_module(spec, transforms.AmplitudeToDB) - - def test_torchscript_create_dct(self): - @torch.jit.script - def jit_method(n_mfcc, n_mels, norm): - # type: (int, int, Optional[str]) -> Tensor - return F.create_dct(n_mfcc, n_mels, norm) - - n_mfcc = 40 - n_mels = 128 - norm = 'ortho' - - jit_out = jit_method(n_mfcc, n_mels, norm) - py_out = F.create_dct(n_mfcc, n_mels, norm) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MFCC(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.MFCC) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MelSpectrogram(self): - tensor = torch.rand((1, 1000), device="cuda") - - self._test_script_module(tensor, transforms.MelSpectrogram) - - def test_torchscript_mu_law_encoding(self): - @torch.jit.script - def jit_method(tensor, qc): - # type: (Tensor, int) -> Tensor - return F.mu_law_encoding(tensor, qc) - - tensor = torch.rand((1, 10)) - qc = 256 - - jit_out = jit_method(tensor, qc) - py_out = F.mu_law_encoding(tensor, qc) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MuLawEncoding(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.MuLawEncoding) - - def test_torchscript_mu_law_decoding(self): - @torch.jit.script - def jit_method(tensor, qc): - # type: (Tensor, int) -> Tensor - return F.mu_law_decoding(tensor, qc) - - tensor = torch.rand((1, 10)) - qc = 256 - - jit_out = jit_method(tensor, qc) - py_out = F.mu_law_decoding(tensor, qc) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_MuLawDecoding(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.MuLawDecoding) - - -if __name__ == '__main__': - unittest.main() From 4c954a7d4c2e36cbc23af9e166be8a6dd5802492 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 14:14:53 -0500 Subject: [PATCH 07/16] update mel scale. --- torchaudio/transforms.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 1b361ce526..cd16d3cd77 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -126,9 +126,11 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N super(MelScale, self).__init__() self.n_mels = n_mels self.sample_rate = sample_rate - self.f_max = f_max if f_max is not None else float(sample_rate // 2) - assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + self.f_max = f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min + + assert f_min <= f_max, 'Require f_min: %f < f_max: %f' % (f_min, f_max) + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) self.fb = torch.jit.Attribute(fb, torch.Tensor) From a68a39012386fb8d25385760a1328a5d6e21477c Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 14:52:58 -0500 Subject: [PATCH 08/16] explictly converting to float. --- torchaudio/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index cd16d3cd77..e5f0ea9cd6 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -126,10 +126,10 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N super(MelScale, self).__init__() self.n_mels = n_mels self.sample_rate = sample_rate - self.f_max = f_max = f_max if f_max is not None else float(sample_rate // 2) + self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min - assert f_min <= f_max, 'Require f_min: %f < f_max: %f' % (f_min, f_max) + assert float(f_min) <= float(self.f_max), 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) From 278435518cd78eb6bc0ad458f460a2e0b15a887d Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 15:44:19 -0500 Subject: [PATCH 09/16] debug. --- torchaudio/transforms.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index e5f0ea9cd6..31cd153655 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -129,7 +129,13 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min - assert float(f_min) <= float(self.f_max), 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + print(f_min) + print(self.f_max) + + a = float(f_min) + b = float(self.f_max) + + assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) From f09ba2d78e94e4e711b393a41bdf3a27609fbc53 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 16:43:50 -0500 Subject: [PATCH 10/16] use attribute and scriptmodule in one place. --- test/test_transforms.py | 13 +------------ torchaudio/transforms.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index c660a0c502..ca8d7605fa 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -21,18 +21,7 @@ def _get_script_module(f, *args): - # takes a transform function `f` and wraps it in a script module - class MyModule(torch.jit.ScriptModule): - def __init__(self): - super(MyModule, self).__init__() - self.module = f(*args) - self.module.eval() - - @torch.jit.script_method - def forward(self, tensor): - return self.module(tensor) - - return MyModule() + return torch.jit.script(f()) def _test_script_module(f, tensor, *args): diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 31cd153655..184fd012d0 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -48,7 +48,7 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None, self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.window = torch.jit.Attribute(window, torch.Tensor) + self.window = window self.pad = pad self.power = power self.normalized = normalized @@ -67,7 +67,7 @@ def forward(self, waveform): self.win_length, self.power, self.normalized) -class AmplitudeToDB(torch.nn.Module): +class AmplitudeToDB(torch.jit.ScriptModule): r"""Turns a tensor from the power/amplitude scale to the decibel scale. This output depends on the maximum value in the input tensor, and so @@ -84,7 +84,7 @@ class AmplitudeToDB(torch.nn.Module): def __init__(self, stype='power', top_db=None): super(AmplitudeToDB, self).__init__() - self.stype = torch.jit.Attribute(stype, str) + self.stype = stype if top_db is not None and top_db < 0: raise ValueError('top_db must be positive value') self.top_db = torch.jit.Attribute(top_db, Optional[float]) @@ -129,17 +129,12 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min - print(f_min) - print(self.f_max) - - a = float(f_min) - b = float(self.f_max) assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) - self.fb = torch.jit.Attribute(fb, torch.Tensor) + self.fb = fb def forward(self, specgram): r""" @@ -199,7 +194,7 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.pad = pad self.n_mels = n_mels # number of mel frequency bins - self.f_max = torch.jit.Attribute(f_max, Optional[float]) + self.f_max = f_max self.f_min = f_min self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length, @@ -251,7 +246,7 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m self.sample_rate = sample_rate self.n_mfcc = n_mfcc self.dct_type = dct_type - self.norm = torch.jit.Attribute(norm, Optional[str]) + self.norm = norm self.top_db = 80.0 self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) @@ -263,7 +258,7 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m if self.n_mfcc > self.MelSpectrogram.n_mels: raise ValueError('Cannot select more MFCC coefficients than # mel bins') dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) - self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor) + self.dct_mat = dct_mat self.log_mels = log_mels def forward(self, waveform): @@ -403,7 +398,7 @@ class ComputeDeltas(torch.nn.Module): def __init__(self, win_length=5, mode="replicate"): super(ComputeDeltas, self).__init__() self.win_length = win_length - self.mode = torch.jit.Attribute(mode, str) + self.mode = mode def forward(self, specgram): r""" From d54fb61ed068f475d5ef04ac07091f2da4d665e5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 16:47:27 -0500 Subject: [PATCH 11/16] simplifying wrapper function. --- test/test_transforms.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index ca8d7605fa..bddb578b3f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -20,16 +20,10 @@ print("Run test with cuda:", RUN_CUDA) -def _get_script_module(f, *args): - return torch.jit.script(f()) - - def _test_script_module(f, tensor, *args): - # tests a script module that wraps a transform function `f` by feeding - # the tensor into the forward function - jit_method = _get_script_module(f, *args) py_method = f(*args) + jit_method = torch.jit.script(py_method) jit_out = jit_method(tensor) py_out = py_method(tensor) From 69758d4a214bd7bdde2d5ac71d54d95f89357df1 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 16:52:35 -0500 Subject: [PATCH 12/16] refactor. --- test/test_transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index bddb578b3f..e8a2b7410d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -25,8 +25,8 @@ def _test_script_module(f, tensor, *args): py_method = f(*args) jit_method = torch.jit.script(py_method) - jit_out = jit_method(tensor) py_out = py_method(tensor) + jit_out = jit_method(tensor) assert torch.allclose(jit_out, py_out) @@ -34,11 +34,11 @@ def _test_script_module(f, tensor, *args): tensor = tensor.to("cuda") - jit_method = _get_script_module(f, *args).cuda() - py_method = f(*args).cuda() + py_method = py_method.cuda() + jit_method = torch.jit.script(py_method) - jit_out = jit_method(tensor) py_out = py_method(tensor) + jit_out = jit_method(tensor) assert torch.allclose(jit_out, py_out) From f9feac61376c56713ff30337f77e3b9e145209d5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 17:16:02 -0500 Subject: [PATCH 13/16] remove script decorator. --- torchaudio/functional.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index e2d3759f42..b06910f7c9 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -452,7 +452,6 @@ def complex_norm(complex_tensor, power=1.0): return torch.norm(complex_tensor, 2, -1).pow(power) -@torch.jit.script def angle(complex_tensor): # type: (Tensor) -> Tensor r"""Compute the angle of complex tensor input. @@ -466,7 +465,6 @@ def angle(complex_tensor): return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) -@torch.jit.script def magphase(complex_tensor, power=1.0): # type: (Tensor, float) -> Tuple[Tensor, Tensor] r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. @@ -547,7 +545,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): return complex_specgrams_stretch -@torch.jit.script def lfilter(waveform, a_coeffs, b_coeffs): # type: (Tensor, Tensor, Tensor) -> Tensor r""" @@ -622,7 +619,6 @@ def lfilter(waveform, a_coeffs, b_coeffs): return output -@torch.jit.script def biquad(waveform, b0, b1, b2, a0, a1, a2): # type: (Tensor, float, float, float, float, float, float) -> Tensor r"""Performs a biquad filter of input tensor. Initial conditions set to 0. @@ -657,7 +653,6 @@ def _dB2Linear(x): return math.exp(x * math.log(10) / 20.0) -@torch.jit.script def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): # type: (Tensor, int, float, float) -> Tensor r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation. @@ -687,7 +682,6 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -@torch.jit.script def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): # type: (Tensor, int, float, float) -> Tensor r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation. @@ -866,7 +860,6 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): return output -@torch.jit.script def _compute_nccf(waveform, sample_rate, frame_time, freq_low): # type: (Tensor, int, float, int) -> Tensor r""" From a632aabe429a911bfba082ebd84c9c42332e8ce2 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 17:56:28 -0500 Subject: [PATCH 14/16] flake8. --- test/test_functional.py | 2 +- test/test_functional_filtering.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index b8de12cd8c..4e36b4dce2 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -492,7 +492,7 @@ def test_torchscript_mu_law_decoding(self): _test_torchscript_functional(F.mu_law_decoding, tensor, qc) - def test_torchscript_mu_law_decoding(self): + def test_torchscript_complex_norm(self): complex_tensor = torch.randn(1, 2, 1025, 400, 2), power = 2 diff --git a/test/test_functional_filtering.py b/test/test_functional_filtering.py index e58894cb63..b90c41278b 100644 --- a/test/test_functional_filtering.py +++ b/test/test_functional_filtering.py @@ -197,7 +197,7 @@ def test_perf_biquad_filtering(self): assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) _test_torchscript_functional( - F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]) + F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]) ) From 3f93896a69dcdf60c8b8d17d382300ee8ede68f1 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 19 Nov 2019 17:58:00 -0500 Subject: [PATCH 15/16] flake8. --- torchaudio/transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 184fd012d0..c7e7bfa0ba 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -129,7 +129,6 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min - assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( From df70da4e0aac1013d01c422c53f4aa39e32fae6d Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 20 Nov 2019 09:52:15 -0500 Subject: [PATCH 16/16] apply to augmentations too. --- test/test_transforms.py | 20 ++++++++++++++++++-- torchaudio/augmentations.py | 8 +++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index e8a2b7410d..bc82f0a673 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,6 +4,7 @@ import torch import torchaudio +import torchaudio.augmentations as A import torchaudio.transforms as transforms import torchaudio.functional as F from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY @@ -20,9 +21,9 @@ print("Run test with cuda:", RUN_CUDA) -def _test_script_module(f, tensor, *args): +def _test_script_module(f, tensor, *args, **kwargs): - py_method = f(*args) + py_method = f(*args, **kwargs) jit_method = torch.jit.script(py_method) py_out = py_method(tensor) @@ -418,6 +419,21 @@ def test_batch_spectrogram(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + def test_scriptmodule_TimeStretch(self): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.rand((10, 2, n_freq, 10, 2)) + _test_script_module(A.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate) + + def test_scriptmodule_FrequencyMasking(self): + tensor = torch.rand((10, 2, 50, 10, 2)) + _test_script_module(A.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False) + + def test_scriptmodule_TimeMasking(self): + tensor = torch.rand((10, 2, 50, 10, 2)) + _test_script_module(A.TimeMasking, tensor, time_mask_param=30, iid_masks=False) + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index efa07e9fd8..ed1573bdac 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -24,14 +24,13 @@ class TimeStretch(torch.jit.ScriptModule): def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): super(TimeStretch, self).__init__() + self.fixed_rate = fixed_rate + n_fft = (n_freq - 1) * 2 hop_length = hop_length if hop_length is not None else n_fft // 2 - self.fixed_rate = fixed_rate phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] - self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) - @torch.jit.script_method def forward(self, complex_specgrams, overriding_rate=None): # type: (Tensor, Optional[float]) -> Tensor r""" @@ -63,7 +62,7 @@ def forward(self, complex_specgrams, overriding_rate=None): return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) -class _AxisMasking(torch.jit.ScriptModule): +class _AxisMasking(torch.nn.Module): r""" Apply masking to a spectrogram. Args: @@ -80,7 +79,6 @@ def __init__(self, mask_param, axis, iid_masks): self.axis = axis self.iid_masks = iid_masks - @torch.jit.script_method def forward(self, specgram, mask_value=0.): # type: (Tensor, float) -> Tensor r"""