From dbfbdce2bb7ced5e7a523e82e1cc2a9609f81ca8 Mon Sep 17 00:00:00 2001 From: Pankaj Patil Date: Fri, 11 Jun 2021 19:19:43 +0530 Subject: [PATCH 1/3] Remove lazy behavior from MelScale --- .../torchscript_consistency_impl.py | 4 --- torchaudio/transforms.py | 34 +------------------ 2 files changed, 1 insertion(+), 37 deletions(-) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 5258343181..506ca9af6c 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -59,10 +59,6 @@ def test_AmplitudeToDB(self): spec = torch.rand((6, 201)) self._assert_consistency(T.AmplitudeToDB(), spec) - def test_MelScale_invalid(self): - with self.assertRaises(ValueError): - torch.jit.script(T.MelScale()) - def test_MelScale(self): spec_f = torch.rand((1, 201, 6)) self._assert_consistency(T.MelScale(n_stft=201), spec_f) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4d9ae9d736..cbb259b684 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -255,7 +255,7 @@ def __init__(self, sample_rate: int = 16000, f_min: float = 0., f_max: Optional[float] = None, - n_stft: Optional[int] = None, + n_stft: int = 201, norm: Optional[str] = None, mel_scale: str = "htk") -> None: super(MelScale, self).__init__() @@ -268,33 +268,8 @@ def __init__(self, assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - if n_stft is None or n_stft == 0: - warnings.warn( - 'Initialization of torchaudio.transforms.MelScale with an unset weight ' - '`n_stft=None` is deprecated and will be removed in release 0.10. ' - 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' - 'Refer to https://github.com/pytorch/audio/issues/1510 ' - 'for more details.' - ) - - fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( - n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, - self.mel_scale) self.register_buffer('fb', fb) - def __prepare_scriptable__(self): - r"""If `self.fb` is empty, the `forward` method will try to resize the parameter, - which does not work once the transform is scripted. However, this error does not happen - until the transform is executed. This is inconvenient especially if the resulting - TorchScript object is executed in other environments. Therefore, we check the - validity of `self.fb` here and fail if the resulting TS does not work. - - Returns: - MelScale: self - """ - if self.fb.numel() == 0: - raise ValueError("n_stft must be provided at construction") - return self def forward(self, specgram: Tensor) -> Tensor: r""" @@ -309,13 +284,6 @@ def forward(self, specgram: Tensor) -> Tensor: shape = specgram.size() specgram = specgram.reshape(-1, shape[-2], shape[-1]) - if self.fb.numel() == 0: - tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, - self.n_mels, self.sample_rate, self.norm, - self.mel_scale) - # Attributes cannot be reassigned outside __init__ so workaround - self.fb.resize_(tmp_fb.size()) - self.fb.copy_(tmp_fb) # (channel, frequency, time).transpose(...) dot (frequency, n_mels) # -> (channel, time, n_mels).transpose(...) From 40b5a3f1448fa69fc4636518da7cdb738d5de903 Mon Sep 17 00:00:00 2001 From: Pankaj Patil Date: Mon, 21 Jun 2021 21:12:31 +0530 Subject: [PATCH 2/3] Implement Feedback --- torchaudio/transforms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index cbb259b684..7c97f3bb12 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -267,7 +267,9 @@ def __init__(self, self.mel_scale = mel_scale assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - + F.create_fb_matrix( + n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, + self.mel_scale) self.register_buffer('fb', fb) From 9a1a116e712c1cd32ed2b6f23d16e22e02e04e52 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 13 Jul 2021 21:36:20 -0400 Subject: [PATCH 3/3] Update torchaudio/transforms.py --- torchaudio/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 7c97f3bb12..cb668d16cb 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -267,7 +267,7 @@ def __init__(self, self.mel_scale = mel_scale assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - F.create_fb_matrix( + fb = F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) self.register_buffer('fb', fb)