diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py index 8a92662200..44f254e540 100644 --- a/test/torchaudio_unittest/transforms/transforms_test_impl.py +++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py @@ -1,3 +1,5 @@ +import warnings + import torch import torchaudio.transforms as T @@ -39,7 +41,7 @@ def test_InverseMelScale(self): get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), n_fft=n_fft, power=power).to(self.device, self.dtype) input = T.MelScale( - n_mels=n_mels, sample_rate=sample_rate + n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft ).to(self.device, self.dtype)(expected) # Run transform @@ -59,3 +61,19 @@ def test_InverseMelScale(self): assert _get_ratio(relative_diff < 1e-1) > 0.2 assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-5) > 1e-5 + + def test_melscale_unset_weight_warning(self): + """Issue a warning if MelScale initialized without a weight + + As part of the deprecation of lazy intialization behavior (#1510), + issue a warning if `n_stft` is not set. + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + T.MelScale(n_mels=64, sample_rate=8000) + assert len(caught_warnings) == 1 + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + T.MelScale(n_mels=64, sample_rate=8000, n_stft=201) + assert len(caught_warnings) == 0 diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4704c4eb39..66faa0d753 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -283,6 +283,15 @@ def __init__(self, assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) + if n_stft is None or n_stft == 0: + warnings.warn( + 'Initialization of torchaudio.transforms.MelScale with an unset weight ' + '`n_stft=None` is deprecated and will be removed from a future release. ' + 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' + 'Refer to https://github.com/pytorch/audio/issues/1510 ' + 'for more details.' + ) + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)