diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index 85e6be082e..4d03789d32 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -33,7 +33,7 @@ def test_batch_Resample(self): self.assertEqual(computed, expected) def test_batch_MelScale(self): - specgram = torch.randn(2, 31, 2786) + specgram = torch.randn(2, 201, 256) # Single then transform then batch expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1) @@ -41,7 +41,7 @@ def test_batch_MelScale(self): # Batch then transform computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1)) - # shape = (3, 2, 201, 1394) + # shape = (3, 2, 128, 256) self.assertEqual(computed, expected) def test_batch_InverseMelScale(self): diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 370708dd21..ff640b645e 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -59,10 +59,6 @@ def test_AmplitudeToDB(self): spec = torch.rand((6, 201)) self._assert_consistency(T.AmplitudeToDB(), spec) - def test_MelScale_invalid(self): - with self.assertRaises(ValueError): - torch.jit.script(T.MelScale()) - def test_MelScale(self): spec_f = torch.rand((1, 201, 6)) self._assert_consistency(T.MelScale(n_stft=201), spec_f) diff --git a/test/torchaudio_unittest/transforms/transforms_test.py b/test/torchaudio_unittest/transforms/transforms_test.py index 3681d18199..a94ac46d01 100644 --- a/test/torchaudio_unittest/transforms/transforms_test.py +++ b/test/torchaudio_unittest/transforms/transforms_test.py @@ -55,17 +55,17 @@ def test_AmplitudeToDB(self): self.assertEqual(mag_to_db_torch, power_to_db_torch) def test_melscale_load_save(self): - specgram = torch.ones(1, 1000, 100) + specgram = torch.ones(1, 201, 100) melscale_transform = transforms.MelScale() melscale_transform(specgram) - melscale_transform_copy = transforms.MelScale(n_stft=1000) + melscale_transform_copy = transforms.MelScale() melscale_transform_copy.load_state_dict(melscale_transform.state_dict()) fb = melscale_transform.fb fb_copy = melscale_transform_copy.fb - self.assertEqual(fb_copy.size(), (1000, 128)) + self.assertEqual(fb_copy.size(), (201, 128)) self.assertEqual(fb, fb_copy) def test_melspectrogram_load_save(self): diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py index 95293cb40a..8284a70be3 100644 --- a/test/torchaudio_unittest/transforms/transforms_test_impl.py +++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py @@ -1,5 +1,3 @@ -import warnings - import torch import torchaudio.transforms as T @@ -63,22 +61,6 @@ def test_InverseMelScale(self): assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-5) > 1e-5 - def test_melscale_unset_weight_warning(self): - """Issue a warning if MelScale initialized without a weight - - As part of the deprecation of lazy intialization behavior (#1510), - issue a warning if `n_stft` is not set. - """ - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always") - T.MelScale(n_mels=64, sample_rate=8000) - assert len(caught_warnings) == 1 - - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter("always") - T.MelScale(n_mels=64, sample_rate=8000, n_stft=201) - assert len(caught_warnings) == 0 - @nested_params( ["sinc_interpolation", "kaiser_window"], [16000, 44100], diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 592302b0e4..6d56dd5efb 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -244,9 +244,8 @@ class MelScale(torch.nn.Module): sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) f_min (float, optional): Minimum frequency. (Default: ``0.``) f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) - n_stft (int, optional): Number of bins in STFT. Calculated from first input - if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``) - norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band + n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``) + norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band (area normalization). (Default: ``None``) mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) """ @@ -257,7 +256,7 @@ def __init__(self, sample_rate: int = 16000, f_min: float = 0., f_max: Optional[float] = None, - n_stft: Optional[int] = None, + n_stft: int = 201, norm: Optional[str] = None, mel_scale: str = "htk") -> None: super(MelScale, self).__init__() @@ -269,35 +268,11 @@ def __init__(self, self.mel_scale = mel_scale assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) - - if n_stft is None or n_stft == 0: - warnings.warn( - 'Initialization of torchaudio.transforms.MelScale with an unset weight ' - '`n_stft=None` is deprecated and will be removed in release 0.10. ' - 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' - 'Refer to https://github.com/pytorch/audio/issues/1510 ' - 'for more details.' - ) - - fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( + fb = F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) self.register_buffer('fb', fb) - def __prepare_scriptable__(self): - r"""If `self.fb` is empty, the `forward` method will try to resize the parameter, - which does not work once the transform is scripted. However, this error does not happen - until the transform is executed. This is inconvenient especially if the resulting - TorchScript object is executed in other environments. Therefore, we check the - validity of `self.fb` here and fail if the resulting TS does not work. - - Returns: - MelScale: self - """ - if self.fb.numel() == 0: - raise ValueError("n_stft must be provided at construction") - return self - def forward(self, specgram: Tensor) -> Tensor: r""" Args: @@ -311,14 +286,6 @@ def forward(self, specgram: Tensor) -> Tensor: shape = specgram.size() specgram = specgram.reshape(-1, shape[-2], shape[-1]) - if self.fb.numel() == 0: - tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, - self.n_mels, self.sample_rate, self.norm, - self.mel_scale) - # Attributes cannot be reassigned outside __init__ so workaround - self.fb.resize_(tmp_fb.size()) - self.fb.copy_(tmp_fb) - # (channel, frequency, time).transpose(...) dot (frequency, n_mels) # -> (channel, time, n_mels).transpose(...) mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)