@@ -23,8 +23,7 @@ def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs)
2323 torch .random .manual_seed (42 )
2424 computed = functional (tensors .clone (), * args , ** kwargs )
2525
26- assert expected .shape == computed .shape , (expected .shape , computed .shape )
27- assert torch .allclose (expected , computed , atol = atol , rtol = rtol )
26+ torch .testing .assert_allclose (computed , expected , rtol = rtol , atol = atol )
2827
2928 return tensors , expected
3029
@@ -43,8 +42,7 @@ def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
4342 torch .random .manual_seed (42 )
4443 computed = functional (tensors .clone (), * args , ** kwargs )
4544
46- assert expected .shape == computed .shape , (expected .shape , computed .shape )
47- assert torch .allclose (expected , computed , atol = atol , rtol = rtol )
45+ torch .testing .assert_allclose (computed , expected , rtol = rtol , atol = atol )
4846
4947
5048class TestFunctional (unittest .TestCase ):
@@ -96,8 +94,7 @@ def test_batch_AmplitudeToDB(self):
9694 # Batch then transform
9795 computed = torchaudio .transforms .AmplitudeToDB ()(spec .repeat (3 , 1 , 1 ))
9896
99- assert computed .shape == expected .shape , (computed .shape , expected .shape )
100- assert torch .allclose (computed , expected )
97+ torch .testing .assert_allclose (computed , expected )
10198
10299 def test_batch_Resample (self ):
103100 waveform = torch .randn (2 , 2786 )
@@ -108,8 +105,7 @@ def test_batch_Resample(self):
108105 # Batch then transform
109106 computed = torchaudio .transforms .Resample ()(waveform .repeat (3 , 1 , 1 ))
110107
111- assert computed .shape == expected .shape , (computed .shape , expected .shape )
112- assert torch .allclose (computed , expected )
108+ torch .testing .assert_allclose (computed , expected )
113109
114110 def test_batch_MelScale (self ):
115111 specgram = torch .randn (2 , 31 , 2786 )
@@ -121,8 +117,7 @@ def test_batch_MelScale(self):
121117 computed = torchaudio .transforms .MelScale ()(specgram .repeat (3 , 1 , 1 , 1 ))
122118
123119 # shape = (3, 2, 201, 1394)
124- assert computed .shape == expected .shape , (computed .shape , expected .shape )
125- assert torch .allclose (computed , expected )
120+ torch .testing .assert_allclose (computed , expected )
126121
127122 def test_batch_InverseMelScale (self ):
128123 n_mels = 32
@@ -136,11 +131,10 @@ def test_batch_InverseMelScale(self):
136131 computed = torchaudio .transforms .InverseMelScale (n_stft , n_mels )(mel_spec .repeat (3 , 1 , 1 , 1 ))
137132
138133 # shape = (3, 2, n_mels, 32)
139- assert computed .shape == expected .shape , (computed .shape , expected .shape )
140134
141135 # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
142136 # exactly same result. For this reason, tolerance is very relaxed here.
143- assert torch .allclose (computed , expected , atol = 1.0 )
137+ torch .testing . assert_allclose (computed , expected , atol = 1.0 , rtol = 1e-5 )
144138
145139 def test_batch_compute_deltas (self ):
146140 specgram = torch .randn (2 , 31 , 2786 )
@@ -152,8 +146,7 @@ def test_batch_compute_deltas(self):
152146 computed = torchaudio .transforms .ComputeDeltas ()(specgram .repeat (3 , 1 , 1 , 1 ))
153147
154148 # shape = (3, 2, 201, 1394)
155- assert computed .shape == expected .shape , (computed .shape , expected .shape )
156- assert torch .allclose (computed , expected )
149+ torch .testing .assert_allclose (computed , expected )
157150
158151 def test_batch_mulaw (self ):
159152 test_filepath = os .path .join (
@@ -169,8 +162,7 @@ def test_batch_mulaw(self):
169162 computed = torchaudio .transforms .MuLawEncoding ()(waveform_batched )
170163
171164 # shape = (3, 2, 201, 1394)
172- assert computed .shape == expected .shape , (computed .shape , expected .shape )
173- assert torch .allclose (computed , expected )
165+ torch .testing .assert_allclose (computed , expected )
174166
175167 # Single then transform then batch
176168 waveform_decoded = torchaudio .transforms .MuLawDecoding ()(waveform_encoded )
@@ -180,8 +172,7 @@ def test_batch_mulaw(self):
180172 computed = torchaudio .transforms .MuLawDecoding ()(computed )
181173
182174 # shape = (3, 2, 201, 1394)
183- assert computed .shape == expected .shape , (computed .shape , expected .shape )
184- assert torch .allclose (computed , expected )
175+ torch .testing .assert_allclose (computed , expected )
185176
186177 def test_batch_spectrogram (self ):
187178 test_filepath = os .path .join (
@@ -193,9 +184,7 @@ def test_batch_spectrogram(self):
193184
194185 # Batch then transform
195186 computed = torchaudio .transforms .Spectrogram ()(waveform .repeat (3 , 1 , 1 ))
196-
197- assert computed .shape == expected .shape , (computed .shape , expected .shape )
198- assert torch .allclose (computed , expected )
187+ torch .testing .assert_allclose (computed , expected )
199188
200189 def test_batch_melspectrogram (self ):
201190 test_filepath = os .path .join (
@@ -207,9 +196,7 @@ def test_batch_melspectrogram(self):
207196
208197 # Batch then transform
209198 computed = torchaudio .transforms .MelSpectrogram ()(waveform .repeat (3 , 1 , 1 ))
210-
211- assert computed .shape == expected .shape , (computed .shape , expected .shape )
212- assert torch .allclose (computed , expected )
199+ torch .testing .assert_allclose (computed , expected )
213200
214201 @unittest .skipIf ("sox" not in BACKENDS , "sox not available" )
215202 @AudioBackendScope ("sox" )
@@ -223,9 +210,7 @@ def test_batch_mfcc(self):
223210
224211 # Batch then transform
225212 computed = torchaudio .transforms .MFCC ()(waveform .repeat (3 , 1 , 1 ))
226-
227- assert computed .shape == expected .shape , (computed .shape , expected .shape )
228- assert torch .allclose (computed , expected , atol = 1e-5 )
213+ torch .testing .assert_allclose (computed , expected , atol = 1e-5 , rtol = 1e-5 )
229214
230215 def test_batch_TimeStretch (self ):
231216 test_filepath = os .path .join (
@@ -260,8 +245,7 @@ def test_batch_TimeStretch(self):
260245 hop_length = 512 ,
261246 )(complex_specgrams .repeat (3 , 1 , 1 , 1 , 1 ))
262247
263- assert computed .shape == expected .shape , (computed .shape , expected .shape )
264- assert torch .allclose (computed , expected , atol = 1e-5 )
248+ torch .testing .assert_allclose (computed , expected , atol = 1e-5 , rtol = 1e-5 )
265249
266250 def test_batch_Fade (self ):
267251 test_filepath = os .path .join (
@@ -275,9 +259,7 @@ def test_batch_Fade(self):
275259
276260 # Batch then transform
277261 computed = torchaudio .transforms .Fade (fade_in_len , fade_out_len )(waveform .repeat (3 , 1 , 1 ))
278-
279- assert computed .shape == expected .shape , (computed .shape , expected .shape )
280- assert torch .allclose (computed , expected )
262+ torch .testing .assert_allclose (computed , expected )
281263
282264 def test_batch_Vol (self ):
283265 test_filepath = os .path .join (
@@ -289,9 +271,7 @@ def test_batch_Vol(self):
289271
290272 # Batch then transform
291273 computed = torchaudio .transforms .Vol (gain = 1.1 )(waveform .repeat (3 , 1 , 1 ))
292-
293- assert computed .shape == expected .shape , (computed .shape , expected .shape )
294- assert torch .allclose (computed , expected )
274+ torch .testing .assert_allclose (computed , expected )
295275
296276
297277if __name__ == '__main__' :
0 commit comments