55import torchaudio .functional as F
66from parameterized import parameterized
77import pytest
8+ import itertools
89
910from torchaudio_unittest import common_utils
1011from .functional_impl import Lfilter , Spectrogram
@@ -53,15 +54,15 @@ def test_one_channel(self):
5354 specgram = torch .tensor ([[[1.0 , 2.0 , 3.0 , 4.0 ]]])
5455 expected = torch .tensor ([[[0.5 , 1.0 , 1.0 , 0.5 ]]])
5556 computed = F .compute_deltas (specgram , win_length = 3 )
56- torch . testing . assert_allclose (computed , expected )
57+ self . assertEqual (computed , expected )
5758
5859 def test_two_channels (self ):
5960 specgram = torch .tensor ([[[1.0 , 2.0 , 3.0 , 4.0 ],
6061 [1.0 , 2.0 , 3.0 , 4.0 ]]])
6162 expected = torch .tensor ([[[0.5 , 1.0 , 1.0 , 0.5 ],
6263 [0.5 , 1.0 , 1.0 , 0.5 ]]])
6364 computed = F .compute_deltas (specgram , win_length = 3 )
64- torch . testing . assert_allclose (computed , expected )
65+ self . assertEqual (computed , expected )
6566
6667
6768class TestDetectPitchFrequency (common_utils .TorchaudioTestCase ):
@@ -97,13 +98,13 @@ def test_DB_to_amplitude(self):
9798 db = F .amplitude_to_DB (torch .abs (x ), multiplier , amin , db_multiplier , top_db = None )
9899 x2 = F .DB_to_amplitude (db , ref , power )
99100
100- torch . testing . assert_allclose (x2 , torch .abs (x ), atol = 5e-5 , rtol = 1e-5 )
101+ self . assertEqual (x2 , torch .abs (x ), atol = 5e-5 , rtol = 1e-5 )
101102
102103 # Spectrogram amplitude -> DB -> amplitude
103104 db = F .amplitude_to_DB (spec , multiplier , amin , db_multiplier , top_db = None )
104105 x2 = F .DB_to_amplitude (db , ref , power )
105106
106- torch . testing . assert_allclose (x2 , spec , atol = 5e-5 , rtol = 1e-5 )
107+ self . assertEqual (x2 , spec , atol = 5e-5 , rtol = 1e-5 )
107108
108109 # Waveform power -> DB -> power
109110 multiplier = 10.
@@ -112,61 +113,66 @@ def test_DB_to_amplitude(self):
112113 db = F .amplitude_to_DB (x , multiplier , amin , db_multiplier , top_db = None )
113114 x2 = F .DB_to_amplitude (db , ref , power )
114115
115- torch . testing . assert_allclose (x2 , torch .abs (x ), atol = 5e-5 , rtol = 1e-5 )
116+ self . assertEqual (x2 , torch .abs (x ), atol = 5e-5 , rtol = 1e-5 )
116117
117118 # Spectrogram power -> DB -> power
118119 db = F .amplitude_to_DB (spec , multiplier , amin , db_multiplier , top_db = None )
119120 x2 = F .DB_to_amplitude (db , ref , power )
120121
121- torch . testing . assert_allclose (x2 , spec , atol = 5e-5 , rtol = 1e-5 )
122+ self . assertEqual (x2 , spec , atol = 5e-5 , rtol = 1e-5 )
122123
123124
124- @pytest .mark .parametrize ('complex_tensor' , [
125- torch .randn (1 , 2 , 1025 , 400 , 2 ),
126- torch .randn (1025 , 400 , 2 )
127- ])
128- @pytest .mark .parametrize ('power' , [1 , 2 , 0.7 ])
129- def test_complex_norm (complex_tensor , power ):
130- expected_norm_tensor = complex_tensor .pow (2 ).sum (- 1 ).pow (power / 2 )
131- norm_tensor = F .complex_norm (complex_tensor , power )
125+ class TestComplexNorm (common_utils .TorchaudioTestCase ):
126+ @parameterized .expand (list (itertools .product (
127+ [(1 , 2 , 1025 , 400 , 2 ), (1025 , 400 , 2 )],
128+ [1 , 2 , 0.7 ]
129+ )))
130+ def test_complex_norm (self , shape , power ):
131+ torch .random .manual_seed (42 )
132+ complex_tensor = torch .randn (* shape )
133+ expected_norm_tensor = complex_tensor .pow (2 ).sum (- 1 ).pow (power / 2 )
134+ norm_tensor = F .complex_norm (complex_tensor , power )
135+ self .assertEqual (norm_tensor , expected_norm_tensor , atol = 1e-5 , rtol = 1e-5 )
132136
133- torch .testing .assert_allclose (norm_tensor , expected_norm_tensor , atol = 1e-5 , rtol = 1e-5 )
134137
138+ class TestMaskAlongAxis (common_utils .TorchaudioTestCase ):
139+ @parameterized .expand (list (itertools .product (
140+ [(2 , 1025 , 400 ), (1 , 201 , 100 )],
141+ [100 ],
142+ [0. , 30. ],
143+ [1 , 2 ]
144+ )))
145+ def test_mask_along_axis (self , shape , mask_param , mask_value , axis ):
146+ torch .random .manual_seed (42 )
147+ specgram = torch .randn (* shape )
148+ mask_specgram = F .mask_along_axis (specgram , mask_param , mask_value , axis )
135149
136- @pytest .mark .parametrize ('specgram' , [
137- torch .randn (2 , 1025 , 400 ),
138- torch .randn (1 , 201 , 100 )
139- ])
140- @pytest .mark .parametrize ('mask_param' , [100 ])
141- @pytest .mark .parametrize ('mask_value' , [0. , 30. ])
142- @pytest .mark .parametrize ('axis' , [1 , 2 ])
143- def test_mask_along_axis (specgram , mask_param , mask_value , axis ):
150+ other_axis = 1 if axis == 2 else 2
144151
145- mask_specgram = F .mask_along_axis (specgram , mask_param , mask_value , axis )
152+ masked_columns = (mask_specgram == mask_value ).sum (other_axis )
153+ num_masked_columns = (masked_columns == mask_specgram .size (other_axis )).sum ()
154+ num_masked_columns //= mask_specgram .size (0 )
146155
147- other_axis = 1 if axis == 2 else 2
156+ assert mask_specgram .size () == specgram .size ()
157+ assert num_masked_columns < mask_param
148158
149- masked_columns = (mask_specgram == mask_value ).sum (other_axis )
150- num_masked_columns = (masked_columns == mask_specgram .size (other_axis )).sum ()
151- num_masked_columns //= mask_specgram .size (0 )
152159
153- assert mask_specgram .size () == specgram .size ()
154- assert num_masked_columns < mask_param
160+ class TestMaskAlongAxisIID (common_utils .TorchaudioTestCase ):
161+ @parameterized .expand (list (itertools .product (
162+ [100 ],
163+ [0. , 30. ],
164+ [2 , 3 ]
165+ )))
166+ def test_mask_along_axis_iid (self , mask_param , mask_value , axis ):
167+ torch .random .manual_seed (42 )
168+ specgrams = torch .randn (4 , 2 , 1025 , 400 )
155169
170+ mask_specgrams = F .mask_along_axis_iid (specgrams , mask_param , mask_value , axis )
156171
157- @pytest .mark .parametrize ('mask_param' , [100 ])
158- @pytest .mark .parametrize ('mask_value' , [0. , 30. ])
159- @pytest .mark .parametrize ('axis' , [2 , 3 ])
160- def test_mask_along_axis_iid (mask_param , mask_value , axis ):
161- torch .random .manual_seed (42 )
162- specgrams = torch .randn (4 , 2 , 1025 , 400 )
172+ other_axis = 2 if axis == 3 else 3
163173
164- mask_specgrams = F .mask_along_axis_iid (specgrams , mask_param , mask_value , axis )
174+ masked_columns = (mask_specgrams == mask_value ).sum (other_axis )
175+ num_masked_columns = (masked_columns == mask_specgrams .size (other_axis )).sum (- 1 )
165176
166- other_axis = 2 if axis == 3 else 3
167-
168- masked_columns = (mask_specgrams == mask_value ).sum (other_axis )
169- num_masked_columns = (masked_columns == mask_specgrams .size (other_axis )).sum (- 1 )
170-
171- assert mask_specgrams .size () == specgrams .size ()
172- assert (num_masked_columns < mask_param ).sum () == num_masked_columns .numel ()
177+ assert mask_specgrams .size () == specgrams .size ()
178+ assert (num_masked_columns < mask_param ).sum () == num_masked_columns .numel ()
0 commit comments