Skip to content

Commit 5f5df1d

Browse files
authored
Use torch.testing.assert_allclose (#513)
* grep -l 'torch.allclose' -r test | xargs sed -i 's/assert torch.allclose/torch.testing.assert_allclose/g' * grep -l 'torch.allclose' -r test | xargs sed -i 's/self.assertTrue(torch.allclose(\(.*\)))/torch.testing.assert_allclose(\1)/g' * Fix missing atol/rtol, wrong shape, argument order. Remove redundant shape assertions
1 parent bc1ffb1 commit 5f5df1d

File tree

7 files changed

+72
-96
lines changed

7 files changed

+72
-96
lines changed

test/test_batch_consistency.py

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs)
2323
torch.random.manual_seed(42)
2424
computed = functional(tensors.clone(), *args, **kwargs)
2525

26-
assert expected.shape == computed.shape, (expected.shape, computed.shape)
27-
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
26+
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
2827

2928
return tensors, expected
3029

@@ -43,8 +42,7 @@ def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
4342
torch.random.manual_seed(42)
4443
computed = functional(tensors.clone(), *args, **kwargs)
4544

46-
assert expected.shape == computed.shape, (expected.shape, computed.shape)
47-
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)
45+
torch.testing.assert_allclose(computed, expected, rtol=rtol, atol=atol)
4846

4947

5048
class TestFunctional(unittest.TestCase):
@@ -96,8 +94,7 @@ def test_batch_AmplitudeToDB(self):
9694
# Batch then transform
9795
computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))
9896

99-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
100-
assert torch.allclose(computed, expected)
97+
torch.testing.assert_allclose(computed, expected)
10198

10299
def test_batch_Resample(self):
103100
waveform = torch.randn(2, 2786)
@@ -108,8 +105,7 @@ def test_batch_Resample(self):
108105
# Batch then transform
109106
computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))
110107

111-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
112-
assert torch.allclose(computed, expected)
108+
torch.testing.assert_allclose(computed, expected)
113109

114110
def test_batch_MelScale(self):
115111
specgram = torch.randn(2, 31, 2786)
@@ -121,8 +117,7 @@ def test_batch_MelScale(self):
121117
computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))
122118

123119
# shape = (3, 2, 201, 1394)
124-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
125-
assert torch.allclose(computed, expected)
120+
torch.testing.assert_allclose(computed, expected)
126121

127122
def test_batch_InverseMelScale(self):
128123
n_mels = 32
@@ -136,11 +131,10 @@ def test_batch_InverseMelScale(self):
136131
computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
137132

138133
# shape = (3, 2, n_mels, 32)
139-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
140134

141135
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
142136
# exactly same result. For this reason, tolerance is very relaxed here.
143-
assert torch.allclose(computed, expected, atol=1.0)
137+
torch.testing.assert_allclose(computed, expected, atol=1.0, rtol=1e-5)
144138

145139
def test_batch_compute_deltas(self):
146140
specgram = torch.randn(2, 31, 2786)
@@ -152,8 +146,7 @@ def test_batch_compute_deltas(self):
152146
computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))
153147

154148
# shape = (3, 2, 201, 1394)
155-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
156-
assert torch.allclose(computed, expected)
149+
torch.testing.assert_allclose(computed, expected)
157150

158151
def test_batch_mulaw(self):
159152
test_filepath = os.path.join(
@@ -169,8 +162,7 @@ def test_batch_mulaw(self):
169162
computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)
170163

171164
# shape = (3, 2, 201, 1394)
172-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
173-
assert torch.allclose(computed, expected)
165+
torch.testing.assert_allclose(computed, expected)
174166

175167
# Single then transform then batch
176168
waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
@@ -180,8 +172,7 @@ def test_batch_mulaw(self):
180172
computed = torchaudio.transforms.MuLawDecoding()(computed)
181173

182174
# shape = (3, 2, 201, 1394)
183-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
184-
assert torch.allclose(computed, expected)
175+
torch.testing.assert_allclose(computed, expected)
185176

186177
def test_batch_spectrogram(self):
187178
test_filepath = os.path.join(
@@ -193,9 +184,7 @@ def test_batch_spectrogram(self):
193184

194185
# Batch then transform
195186
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
196-
197-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
198-
assert torch.allclose(computed, expected)
187+
torch.testing.assert_allclose(computed, expected)
199188

200189
def test_batch_melspectrogram(self):
201190
test_filepath = os.path.join(
@@ -207,9 +196,7 @@ def test_batch_melspectrogram(self):
207196

208197
# Batch then transform
209198
computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
210-
211-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
212-
assert torch.allclose(computed, expected)
199+
torch.testing.assert_allclose(computed, expected)
213200

214201
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
215202
@AudioBackendScope("sox")
@@ -223,9 +210,7 @@ def test_batch_mfcc(self):
223210

224211
# Batch then transform
225212
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
226-
227-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
228-
assert torch.allclose(computed, expected, atol=1e-5)
213+
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
229214

230215
def test_batch_TimeStretch(self):
231216
test_filepath = os.path.join(
@@ -260,8 +245,7 @@ def test_batch_TimeStretch(self):
260245
hop_length=512,
261246
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
262247

263-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
264-
assert torch.allclose(computed, expected, atol=1e-5)
248+
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
265249

266250
def test_batch_Fade(self):
267251
test_filepath = os.path.join(
@@ -275,9 +259,7 @@ def test_batch_Fade(self):
275259

276260
# Batch then transform
277261
computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
278-
279-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
280-
assert torch.allclose(computed, expected)
262+
torch.testing.assert_allclose(computed, expected)
281263

282264
def test_batch_Vol(self):
283265
test_filepath = os.path.join(
@@ -289,9 +271,7 @@ def test_batch_Vol(self):
289271

290272
# Batch then transform
291273
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
292-
293-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
294-
assert torch.allclose(computed, expected)
274+
torch.testing.assert_allclose(computed, expected)
295275

296276

297277
if __name__ == '__main__':

test/test_compliance_kaldi.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_
7777

7878
for r in range(m):
7979
extract_window(window, waveform, r, window_size, window_shift, snip_edges)
80-
self.assertTrue(torch.allclose(window, output))
80+
torch.testing.assert_allclose(window, output)
8181

8282
def test_get_strided(self):
8383
# generate any combination where 0 < window_size <= num_samples and
@@ -104,7 +104,7 @@ def _create_data_set(self):
104104
sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
105105
print(y >> 16)
106106
self.assertTrue(sample_rate == sr)
107-
self.assertTrue(torch.allclose(y, sound))
107+
torch.testing.assert_allclose(y, sound)
108108

109109
def _print_diagnostic(self, output, expect_output):
110110
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
@@ -156,8 +156,7 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
156156
output = get_output_fn(sound, args)
157157

158158
self._print_diagnostic(output, kaldi_output)
159-
self.assertTrue(output.shape, kaldi_output.shape)
160-
self.assertTrue(torch.allclose(output, kaldi_output, atol=atol, rtol=rtol))
159+
torch.testing.assert_allclose(output, kaldi_output, atol=atol, rtol=rtol)
161160

162161
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
163162
@AudioBackendScope("sox")
@@ -299,7 +298,7 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
299298
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
300299
estimate = estimate[..., n_to_trim:-n_to_trim]
301300

302-
self.assertTrue(torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol))
301+
torch.testing.assert_allclose(estimate, ground_truth, atol=atol, rtol=rtol)
303302

304303
def test_resample_waveform_downsample_accuracy(self):
305304
for i in range(1, 20):
@@ -324,7 +323,7 @@ def test_resample_waveform_multi_channel(self):
324323
for i in range(num_channels):
325324
single_channel = sound * (i + 1) * 1.5
326325
single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2)
327-
self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4))
326+
torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-8)
328327

329328

330329
if __name__ == '__main__':

test/test_functional.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,21 @@ def test_one_channel(self):
1616
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
1717
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
1818
computed = F.compute_deltas(specgram, win_length=3)
19-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
20-
assert torch.allclose(computed, expected)
19+
torch.testing.assert_allclose(computed, expected)
2120

2221
def test_two_channels(self):
2322
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
2423
[1.0, 2.0, 3.0, 4.0]]])
2524
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
2625
[0.5, 1.0, 1.0, 0.5]]])
2726
computed = F.compute_deltas(specgram, win_length=3)
28-
assert computed.shape == expected.shape, (computed.shape, expected.shape)
29-
assert torch.allclose(computed, expected)
27+
torch.testing.assert_allclose(computed, expected)
3028

3129

3230
def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
3331
# trim sound for case when constructed signal is shorter than original
3432
sound = sound[..., :estimate.size(-1)]
35-
36-
assert sound.shape == estimate.shape, (sound.shape, estimate.shape)
37-
assert torch.allclose(sound, estimate, atol=atol, rtol=rtol)
33+
torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol)
3834

3935

4036
def _test_istft_is_inverse_of_stft(kwargs):
@@ -308,13 +304,13 @@ def test_DB_to_amplitude(self):
308304
db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
309305
x2 = F.DB_to_amplitude(db, ref, power)
310306

311-
self.assertTrue(torch.allclose(torch.abs(x), x2, atol=5e-5))
307+
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
312308

313309
# Spectrogram amplitude -> DB -> amplitude
314310
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
315311
x2 = F.DB_to_amplitude(db, ref, power)
316312

317-
self.assertTrue(torch.allclose(spec, x2, atol=5e-5))
313+
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
318314

319315
# Waveform power -> DB -> power
320316
multiplier = 10.
@@ -323,13 +319,13 @@ def test_DB_to_amplitude(self):
323319
db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
324320
x2 = F.DB_to_amplitude(db, ref, power)
325321

326-
self.assertTrue(torch.allclose(torch.abs(x), x2, atol=5e-5))
322+
torch.testing.assert_allclose(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
327323

328324
# Spectrogram power -> DB -> power
329325
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
330326
x2 = F.DB_to_amplitude(db, ref, power)
331327

332-
self.assertTrue(torch.allclose(spec, x2, atol=5e-5))
328+
torch.testing.assert_allclose(x2, spec, atol=5e-5, rtol=1e-5)
333329

334330

335331
@pytest.mark.parametrize('complex_tensor', [
@@ -341,7 +337,7 @@ def test_complex_norm(complex_tensor, power):
341337
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
342338
norm_tensor = F.complex_norm(complex_tensor, power)
343339

344-
assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5)
340+
torch.testing.assert_allclose(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
345341

346342

347343
@pytest.mark.parametrize('specgram', [

0 commit comments

Comments
 (0)