Skip to content

Commit ac8a8f0

Browse files
author
Brian White
committed
[#1511] Add deprecation warning to MelScale for unset weight
Issue a warning if `n_stft` is unitialized or zero in construction. #1511
1 parent b8b732a commit ac8a8f0

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

test/torchaudio_unittest/transforms/transforms_test_impl.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import torch
24
import torchaudio.transforms as T
35

@@ -39,7 +41,7 @@ def test_InverseMelScale(self):
3941
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2),
4042
n_fft=n_fft, power=power).to(self.device, self.dtype)
4143
input = T.MelScale(
42-
n_mels=n_mels, sample_rate=sample_rate
44+
n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft
4345
).to(self.device, self.dtype)(expected)
4446

4547
# Run transform
@@ -59,3 +61,9 @@ def test_InverseMelScale(self):
5961
assert _get_ratio(relative_diff < 1e-1) > 0.2
6062
assert _get_ratio(relative_diff < 1e-3) > 5e-3
6163
assert _get_ratio(relative_diff < 1e-5) > 1e-5
64+
65+
def test_melscale_unset_weight_warning(self):
66+
with warnings.catch_warnings(record=True) as caught_warnings:
67+
warnings.simplefilter("always")
68+
T.MelScale(n_mels=64, sample_rate=8000)
69+
assert len(caught_warnings) == 1

torchaudio/transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,14 @@ def __init__(self,
283283

284284
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
285285

286+
if n_stft is None or n_stft == 0:
287+
warnings.warn(
288+
'Initialization of torchaudio.transforms.MelScale with an unset weight '
289+
'`n_stft` is deprecated and will be removed from a future release. '
290+
'Please refer to https://github.com/pytorch/audio/issues/1510 '
291+
'for more details.'
292+
)
293+
286294
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
287295
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
288296
self.mel_scale)

0 commit comments

Comments
 (0)