diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py index 96aac4e770..4fbe4d871e 100644 --- a/test/torchaudio_unittest/functional/functional_cpu_test.py +++ b/test/torchaudio_unittest/functional/functional_cpu_test.py @@ -21,6 +21,23 @@ class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase): device = torch.device('cpu') +class TestCreateFBMatrix(common_utils.TorchaudioTestCase): + def test_no_warning_high_n_freq(self): + with pytest.warns(None) as w: + F.create_fb_matrix(288, 0, 8000, 128, 16000) + assert len(w) == 0 + + def test_no_warning_low_n_mels(self): + with pytest.warns(None) as w: + F.create_fb_matrix(201, 0, 8000, 89, 16000) + assert len(w) == 0 + + def test_warning(self): + with pytest.warns(None) as w: + F.create_fb_matrix(201, 0, 8000, 128, 16000) + assert len(w) == 1 + + class TestComputeDeltas(common_utils.TorchaudioTestCase): """Test suite for correctness of compute_deltas""" def test_one_channel(self): diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 4e8e62848e..a6514c79f2 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -313,6 +313,13 @@ def create_fb_matrix( enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels]) fb *= enorm.unsqueeze(0) + if (fb.max(dim=0).values == 0.).any(): + warnings.warn( + "At least one mel filterbank has all zero values. " + f"The value for `n_mels` ({n_mels}) may be set too high. " + f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." + ) + return fb