diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 5258343181..506ca9af6c 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -59,10 +59,6 @@ def test_AmplitudeToDB(self): spec = torch.rand((6, 201)) self._assert_consistency(T.AmplitudeToDB(), spec) - def test_MelScale_invalid(self): - with self.assertRaises(ValueError): - torch.jit.script(T.MelScale()) - def test_MelScale(self): spec_f = torch.rand((1, 201, 6)) self._assert_consistency(T.MelScale(n_stft=201), spec_f) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4d9ae9d736..cb668d16cb 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -255,7 +255,7 @@ def __init__(self, sample_rate: int = 16000, f_min: float = 0., f_max: Optional[float] = None, - n_stft: Optional[int] = None, + n_stft: int = 201, norm: Optional[str] = None, mel_scale: str = "htk") -> None: super(MelScale, self).__init__() @@ -267,34 +267,11 @@ def __init__(self, self.mel_scale = mel_scale assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - - if n_stft is None or n_stft == 0: - warnings.warn( - 'Initialization of torchaudio.transforms.MelScale with an unset weight ' - '`n_stft=None` is deprecated and will be removed in release 0.10. ' - 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' - 'Refer to https://github.com/pytorch/audio/issues/1510 ' - 'for more details.' - ) - - fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( + fb = F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) self.register_buffer('fb', fb) - def __prepare_scriptable__(self): - r"""If `self.fb` is empty, the `forward` method will try to resize the parameter, - which does not work once the transform is scripted. However, this error does not happen - until the transform is executed. This is inconvenient especially if the resulting - TorchScript object is executed in other environments. Therefore, we check the - validity of `self.fb` here and fail if the resulting TS does not work. - - Returns: - MelScale: self - """ - if self.fb.numel() == 0: - raise ValueError("n_stft must be provided at construction") - return self def forward(self, specgram: Tensor) -> Tensor: r""" @@ -309,13 +286,6 @@ def forward(self, specgram: Tensor) -> Tensor: shape = specgram.size() specgram = specgram.reshape(-1, shape[-2], shape[-1]) - if self.fb.numel() == 0: - tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, - self.n_mels, self.sample_rate, self.norm, - self.mel_scale) - # Attributes cannot be reassigned outside __init__ so workaround - self.fb.resize_(tmp_fb.size()) - self.fb.copy_(tmp_fb) # (channel, frequency, time).transpose(...) dot (frequency, n_mels) # -> (channel, time, n_mels).transpose(...)