Skip to content

Commit 1e7d8d2

Browse files
krishnakalyan3
and authored
Replace pytest's parameterization with parameterized (#1157)
Also replaces `assert_allclose` with `assertEqual`. Co-authored-by: krishnakalyan3 <[email protected]>
1 parent df48ba3 commit 1e7d8d2

File tree

2 files changed

+94
-79
lines changed

2 files changed

+94
-79
lines changed

test/torchaudio_unittest/functional/functional_cpu_test.py

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torchaudio.functional as F
66
from parameterized import parameterized
77
import pytest
8+
import itertools
89

910
from torchaudio_unittest import common_utils
1011
from .functional_impl import Lfilter, Spectrogram
@@ -53,15 +54,15 @@ def test_one_channel(self):
5354
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
5455
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
5556
computed = F.compute_deltas(specgram, win_length=3)
56-
torch.testing.assert_allclose(computed, expected)
57+
self.assertEqual(computed, expected)
5758

5859
def test_two_channels(self):
5960
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
6061
[1.0, 2.0, 3.0, 4.0]]])
6162
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
6263
[0.5, 1.0, 1.0, 0.5]]])
6364
computed = F.compute_deltas(specgram, win_length=3)
64-
torch.testing.assert_allclose(computed, expected)
65+
self.assertEqual(computed, expected)
6566

6667

6768
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
@@ -97,13 +98,13 @@ def test_DB_to_amplitude(self):
9798
db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
9899
x2 = F.DB_to_amplitude(db, ref, power)
99100

100-
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
101+
self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
101102

102103
# Spectrogram amplitude -> DB -> amplitude
103104
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
104105
x2 = F.DB_to_amplitude(db, ref, power)
105106

106-
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
107+
self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
107108

108109
# Waveform power -> DB -> power
109110
multiplier = 10.
@@ -112,61 +113,66 @@ def test_DB_to_amplitude(self):
112113
db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
113114
x2 = F.DB_to_amplitude(db, ref, power)
114115

115-
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
116+
self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
116117

117118
# Spectrogram power -> DB -> power
118119
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
119120
x2 = F.DB_to_amplitude(db, ref, power)
120121

121-
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
122+
self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
122123

123124

124-
@pytest.mark.parametrize('complex_tensor', [
125-
torch.randn(1, 2, 1025, 400, 2),
126-
torch.randn(1025, 400, 2)
127-
])
128-
@pytest.mark.parametrize('power', [1, 2, 0.7])
129-
def test_complex_norm(complex_tensor, power):
130-
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
131-
norm_tensor = F.complex_norm(complex_tensor, power)
125+
class TestComplexNorm(common_utils.TorchaudioTestCase):
126+
@parameterized.expand(list(itertools.product(
127+
[(1, 2, 1025, 400, 2), (1025, 400, 2)],
128+
[1, 2, 0.7]
129+
)))
130+
def test_complex_norm(self, shape, power):
131+
torch.random.manual_seed(42)
132+
complex_tensor = torch.randn(*shape)
133+
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
134+
norm_tensor = F.complex_norm(complex_tensor, power)
135+
self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
132136

133-
torch.testing.assert_allclose(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
134137

138+
class TestMaskAlongAxis(common_utils.TorchaudioTestCase):
139+
@parameterized.expand(list(itertools.product(
140+
[(2, 1025, 400), (1, 201, 100)],
141+
[100],
142+
[0., 30.],
143+
[1, 2]
144+
)))
145+
def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
146+
torch.random.manual_seed(42)
147+
specgram = torch.randn(*shape)
148+
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
135149

136-
@pytest.mark.parametrize('specgram', [
137-
torch.randn(2, 1025, 400),
138-
torch.randn(1, 201, 100)
139-
])
140-
@pytest.mark.parametrize('mask_param', [100])
141-
@pytest.mark.parametrize('mask_value', [0., 30.])
142-
@pytest.mark.parametrize('axis', [1, 2])
143-
def test_mask_along_axis(specgram, mask_param, mask_value, axis):
150+
other_axis = 1 if axis == 2 else 2
144151

145-
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
152+
masked_columns = (mask_specgram == mask_value).sum(other_axis)
153+
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
154+
num_masked_columns //= mask_specgram.size(0)
146155

147-
other_axis = 1 if axis == 2 else 2
156+
assert mask_specgram.size() == specgram.size()
157+
assert num_masked_columns < mask_param
148158

149-
masked_columns = (mask_specgram == mask_value).sum(other_axis)
150-
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
151-
num_masked_columns //= mask_specgram.size(0)
152159

153-
assert mask_specgram.size() == specgram.size()
154-
assert num_masked_columns < mask_param
160+
class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase):
161+
@parameterized.expand(list(itertools.product(
162+
[100],
163+
[0., 30.],
164+
[2, 3]
165+
)))
166+
def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
167+
torch.random.manual_seed(42)
168+
specgrams = torch.randn(4, 2, 1025, 400)
155169

170+
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
156171

157-
@pytest.mark.parametrize('mask_param', [100])
158-
@pytest.mark.parametrize('mask_value', [0., 30.])
159-
@pytest.mark.parametrize('axis', [2, 3])
160-
def test_mask_along_axis_iid(mask_param, mask_value, axis):
161-
torch.random.manual_seed(42)
162-
specgrams = torch.randn(4, 2, 1025, 400)
172+
other_axis = 2 if axis == 3 else 3
163173

164-
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
174+
masked_columns = (mask_specgrams == mask_value).sum(other_axis)
175+
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)
165176

166-
other_axis = 2 if axis == 3 else 3
167-
168-
masked_columns = (mask_specgrams == mask_value).sum(other_axis)
169-
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)
170-
171-
assert mask_specgrams.size() == specgrams.size()
172-
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
177+
assert mask_specgrams.size() == specgrams.size()
178+
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()

test/torchaudio_unittest/librosa_compatibility_test.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import torchaudio
88
import torchaudio.functional as F
99
from torchaudio._internal.module_utils import is_module_available
10+
from parameterized import parameterized
11+
import itertools
1012

1113
LIBROSA_AVAILABLE = is_module_available('librosa')
1214

@@ -111,42 +113,49 @@ def test_amplitude_to_DB(self):
111113
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
112114

113115

114-
@pytest.mark.parametrize('complex_specgrams', [
115-
torch.randn(2, 1025, 400, 2)
116-
])
117-
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
118-
@pytest.mark.parametrize('hop_length', [256])
119116
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
120-
def test_phase_vocoder(complex_specgrams, rate, hop_length):
121-
# Due to cummulative sum, numerical error in using torch.float32 will
122-
# result in bottom right values of the stretched sectrogram to not
123-
# match with librosa.
124-
125-
complex_specgrams = complex_specgrams.type(torch.float64)
126-
phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None]
127-
128-
complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
129-
130-
# == Test shape
131-
expected_size = list(complex_specgrams.size())
132-
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
133-
134-
assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
135-
assert complex_specgrams_stretch.size() == torch.Size(expected_size)
136-
137-
# == Test values
138-
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
139-
mono_complex_specgram = complex_specgrams[index].numpy()
140-
mono_complex_specgram = mono_complex_specgram[..., 0] + \
141-
mono_complex_specgram[..., 1] * 1j
142-
expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
143-
rate=rate,
144-
hop_length=hop_length)
145-
146-
complex_stretch = complex_specgrams_stretch[index].numpy()
147-
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
148-
149-
assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)
117+
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
118+
@parameterized.expand(list(itertools.product(
119+
[(2, 1025, 400, 2)],
120+
[0.5, 1.01, 1.3],
121+
[256]
122+
)))
123+
def test_phase_vocoder(self, shape, rate, hop_length):
124+
# Due to cummulative sum, numerical error in using torch.float32 will
125+
# result in bottom right values of the stretched sectrogram to not
126+
# match with librosa.
127+
torch.random.manual_seed(42)
128+
complex_specgrams = torch.randn(*shape)
129+
complex_specgrams = complex_specgrams.type(torch.float64)
130+
phase_advance = torch.linspace(
131+
0,
132+
np.pi * hop_length,
133+
complex_specgrams.shape[-3],
134+
dtype=torch.float64)[..., None]
135+
136+
complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
137+
138+
# == Test shape
139+
expected_size = list(complex_specgrams.size())
140+
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
141+
142+
assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
143+
assert complex_specgrams_stretch.size() == torch.Size(expected_size)
144+
145+
# == Test values
146+
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
147+
mono_complex_specgram = complex_specgrams[index].numpy()
148+
mono_complex_specgram = mono_complex_specgram[..., 0] + \
149+
mono_complex_specgram[..., 1] * 1j
150+
expected_complex_stretch = librosa.phase_vocoder(
151+
mono_complex_specgram,
152+
rate=rate,
153+
hop_length=hop_length)
154+
155+
complex_stretch = complex_specgrams_stretch[index].numpy()
156+
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
157+
158+
self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
150159

151160

152161
def _load_audio_asset(*asset_paths, **kwargs):

0 commit comments

Comments
 (0)