From 88d31426f5107005ff01cb5982cd237bf3f1fb39 Mon Sep 17 00:00:00 2001 From: Brian White Date: Wed, 19 May 2021 13:14:11 +0100 Subject: [PATCH 1/3] [#1511] Add deprecation warning to MelScale for unset weight Issue a warning if `n_stft` is unitialized or zero in construction. https://github.com/pytorch/audio/issues/1511 --- .../transforms/transforms_test_impl.py | 20 ++++++++++++++++++- torchaudio/transforms.py | 9 +++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py index 8a92662200..44f254e540 100644 --- a/test/torchaudio_unittest/transforms/transforms_test_impl.py +++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py @@ -1,3 +1,5 @@ +import warnings + import torch import torchaudio.transforms as T @@ -39,7 +41,7 @@ def test_InverseMelScale(self): get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), n_fft=n_fft, power=power).to(self.device, self.dtype) input = T.MelScale( - n_mels=n_mels, sample_rate=sample_rate + n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft ).to(self.device, self.dtype)(expected) # Run transform @@ -59,3 +61,19 @@ def test_InverseMelScale(self): assert _get_ratio(relative_diff < 1e-1) > 0.2 assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-5) > 1e-5 + + def test_melscale_unset_weight_warning(self): + """Issue a warning if MelScale initialized without a weight + + As part of the deprecation of lazy intialization behavior (#1510), + issue a warning if `n_stft` is not set. + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + T.MelScale(n_mels=64, sample_rate=8000) + assert len(caught_warnings) == 1 + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + T.MelScale(n_mels=64, sample_rate=8000, n_stft=201) + assert len(caught_warnings) == 0 diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4704c4eb39..b14b3b93ad 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -283,6 +283,15 @@ def __init__(self, assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) + if n_stft is None or n_stft == 0: + warnings.warn( + 'Initialization of torchaudio.transforms.MelScale with an unset weight ' + '`n_stft` is deprecated and will be removed from a future release. ' + 'Please set a proper n_stftvalue. Typically this isn_fft // 2 + 1. ' + 'Refer to https://github.com/pytorch/audio/issues/1510 ' + 'for more details.' + ) + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) From 7175f88705e40cee5b91b29d938293e8b70e50aa Mon Sep 17 00:00:00 2001 From: Brian White Date: Wed, 19 May 2021 17:53:58 +0100 Subject: [PATCH 2/3] Update torchaudio/transforms.py Co-authored-by: moto <855818+mthrok@users.noreply.github.com> --- torchaudio/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index b14b3b93ad..2749be2132 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -287,7 +287,7 @@ def __init__(self, warnings.warn( 'Initialization of torchaudio.transforms.MelScale with an unset weight ' '`n_stft` is deprecated and will be removed from a future release. ' - 'Please set a proper n_stftvalue. Typically this isn_fft // 2 + 1. ' + 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' 'Refer to https://github.com/pytorch/audio/issues/1510 ' 'for more details.' ) From fd15a50c1e8ebbd7863218a1448c3d71ab47d0b7 Mon Sep 17 00:00:00 2001 From: Brian White Date: Wed, 19 May 2021 17:54:09 +0100 Subject: [PATCH 3/3] Update torchaudio/transforms.py Co-authored-by: moto <855818+mthrok@users.noreply.github.com> --- torchaudio/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 2749be2132..66faa0d753 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -286,7 +286,7 @@ def __init__(self, if n_stft is None or n_stft == 0: warnings.warn( 'Initialization of torchaudio.transforms.MelScale with an unset weight ' - '`n_stft` is deprecated and will be removed from a future release. ' + '`n_stft=None` is deprecated and will be removed from a future release. ' 'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. ' 'Refer to https://github.com/pytorch/audio/issues/1510 ' 'for more details.'