Skip to content

Commit 02b898f

Browse files
Get rid of whitenoise and sinewave files from test (#783)
* Get rid of sine wave files and whitenoise files
* Refactor integer encoding
* Relax rtol from 1e-8 to 1e-7 for compliance kaldi
* Relax waveform multi channel resample atol to 1e-7 from 1e-8
* Relax tolerance for length consistency for speed effect

Co-authored-by: moto <[email protected]>
1 parent 8181a83 commit 02b898f

File tree

9 files changed

+90
-72
lines changed

9 files changed

+90
-72
lines changed
-434 KB
Binary file not shown.
-434 KB
Binary file not shown.

test/assets/whitenoise.mp3

-55.8 KB
Binary file not shown.

test/assets/whitenoise.wav

-431 KB
Binary file not shown.

test/common_utils/data_utils.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@ def get_asset_path(*paths):
1313
return os.path.join(_TEST_DIR_PATH, 'assets', *paths)
1414

1515

16+
def convert_tensor_encoding(
    tensor: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Convert a waveform with values in [-1, 1] to the requested encoding.

    Positive and negative samples are scaled asymmetrically so that ``1.0``
    maps to the largest and ``-1.0`` to the smallest representable integer
    of the target type. ``uint8`` output is additionally offset by 128, so
    that silence sits at mid-scale.

    The input tensor is never modified.

    Args:
        tensor: input tensor, assumed to hold values between -1 and 1
        dtype: desired output dtype. ``torch.int32``, ``torch.int16`` and
            ``torch.uint8`` are rescaled; any other dtype is a plain cast.

    Returns:
        Tensor: tensor of ``dtype`` with the same shape as ``tensor``.
    """
    if dtype == torch.int32:
        # float32 cannot represent 2147483647 (it rounds up to 2147483648.0),
        # so scaling in float32 overflows int32 for full-scale positive
        # samples. Do the arithmetic in float64, where both scale factors and
        # every float32 sample times them are exact.
        wide = tensor.to(torch.float64)
        scale = (wide > 0) * 2147483647 + (wide < 0) * 2147483648
        return (wide * scale).to(dtype)
    if dtype == torch.int16:
        scale = (tensor > 0) * 32767 + (tensor < 0) * 32768
        return (tensor * scale).to(dtype)
    if dtype == torch.uint8:
        scale = (tensor > 0) * 127 + (tensor < 0) * 128
        # Shift the signed range [-128, 127] to the unsigned range [0, 255].
        return (tensor * scale + 128).to(dtype)
    return tensor.to(dtype)
36+
37+
1638
def get_whitenoise(
1739
*,
1840
sample_rate: int = 16000,
@@ -43,25 +65,17 @@ def get_whitenoise(
4365
if dtype not in [torch.float32, torch.int32, torch.int16, torch.uint8]:
4466
raise NotImplementedError(f'dtype {dtype} is not supported.')
4567
# According to the doc, forking rng on all CUDA devices is slow when there are many CUDA devices,
46-
# so we only folk on CPU, generate values and move the data to the given device
68+
# so we only fork on CPU, generate values and move the data to the given device
4769
with torch.random.fork_rng([]):
4870
torch.random.manual_seed(seed)
49-
tensor = torch.randn([sample_rate * duration], dtype=torch.float32, device='cpu')
71+
tensor = torch.randn([int(sample_rate * duration)], dtype=torch.float32, device='cpu')
5072
tensor /= 2.0
5173
tensor *= scale_factor
5274
tensor.clamp_(-1.0, 1.0)
53-
if dtype == torch.int32:
54-
tensor *= (tensor > 0) * 2147483647 + (tensor < 0) * 2147483648
55-
if dtype == torch.int16:
56-
tensor *= (tensor > 0) * 32767 + (tensor < 0) * 32768
57-
if dtype == torch.uint8:
58-
tensor *= (tensor > 0) * 127 + (tensor < 0) * 128
59-
tensor += 128
60-
tensor = tensor.to(dtype)
6175
tensor = tensor.repeat([n_channels, 1])
6276
if not channels_first:
6377
tensor = tensor.t()
64-
return tensor.to(device=device)
78+
return convert_tensor_encoding(tensor, dtype)
6579

6680

6781
def get_sinusoid(
@@ -91,8 +105,8 @@ def get_sinusoid(
91105
dtype = getattr(torch, dtype)
92106
pie2 = 2 * 3.141592653589793
93107
end = pie2 * frequency * duration
94-
theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device)
95-
sin = torch.sin(theta, out=None).repeat([n_channels, 1])
108+
theta = torch.linspace(0, end, int(sample_rate * duration), dtype=torch.float32, device=device)
109+
tensor = torch.sin(theta, out=None).repeat([n_channels, 1])
96110
if not channels_first:
97-
sin = sin.t()
98-
return sin
111+
tensor = tensor.t()
112+
return convert_tensor_encoding(tensor, dtype)

test/functional_cpu_test.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
import torchaudio
66
import torchaudio.functional as F
7+
from parameterized import parameterized
78
import pytest
89

910
from . import common_utils
@@ -299,24 +300,18 @@ def test_linearity_of_istft4(self):
299300

300301

301302
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
302-
def test_pitch(self):
303-
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
304-
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
305-
306-
# Files from https://www.mediacollege.com/audio/tone/download/
307-
tests = [
308-
(test_filepath_100, 100),
309-
(test_filepath_440, 440),
310-
]
311-
312-
for filename, freq_ref in tests:
313-
waveform, sample_rate = common_utils.load_wav(filename)
314-
315-
freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
316-
317-
threshold = 1
318-
s = ((freq - freq_ref).abs() > threshold).sum()
319-
self.assertFalse(s)
303+
@parameterized.expand([(100,), (440,)])
304+
def test_pitch(self, frequency):
305+
sample_rate = 44100
306+
test_sine_waveform = common_utils.get_sinusoid(
307+
frequency=frequency, sample_rate=sample_rate, duration=5,
308+
)
309+
310+
freq = torchaudio.functional.detect_pitch_frequency(test_sine_waveform, sample_rate)
311+
312+
threshold = 1
313+
s = ((freq - frequency).abs() > threshold).sum()
314+
self.assertFalse(s)
320315

321316

322317
class TestDB_to_amplitude(common_utils.TorchaudioTestCase):

test/test_batch_consistency.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Test numerical consistency among single input and batched input."""
22
import unittest
3+
import itertools
4+
from parameterized import parameterized
35

46
import torch
57
import torchaudio
@@ -47,17 +49,15 @@ def test_griffinlim(self):
4749
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
4850
)
4951

50-
def test_detect_pitch_frequency(self):
51-
filenames = [
52-
'steam-train-whistle-daniel_simon.wav', # 2ch 44100Hz
53-
# Files from https://www.mediacollege.com/audio/tone/download/
54-
'100Hz_44100Hz_16bit_05sec.wav', # 1ch
55-
'440Hz_44100Hz_16bit_05sec.wav', # 1ch
56-
]
57-
for filename in filenames:
58-
filepath = common_utils.get_asset_path(filename)
59-
waveform, sample_rate = torchaudio.load(filepath)
60-
self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
52+
@parameterized.expand(list(itertools.product(
53+
[100, 440],
54+
[8000, 16000, 44100],
55+
[1, 2],
56+
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
57+
def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
58+
waveform = common_utils.get_sinusoid(frequency=frequency, sample_rate=sample_rate,
59+
n_channels=n_channels, duration=5)
60+
self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
6161

6262
def test_istft(self):
6363
stft = torch.tensor([
@@ -80,8 +80,10 @@ def test_overdrive(self):
8080
self.assert_batch_consistencies(F.overdrive, waveform, gain=45, colour=30)
8181

8282
def test_phaser(self):
83-
filepath = common_utils.get_asset_path("whitenoise.wav")
84-
waveform, sample_rate = torchaudio.load(filepath)
83+
sample_rate = 44100
84+
waveform = common_utils.get_whitenoise(
85+
sample_rate=sample_rate, duration=5,
86+
)
8587
self.assert_batch_consistencies(F.phaser, waveform, sample_rate)
8688

8789
def test_flanger(self):

test/test_compliance_kaldi.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,25 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
4747

4848

4949
@common_utils.skipIfNoSoxBackend
50-
class Test_Kaldi(common_utils.TorchaudioTestCase):
50+
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
5151
backend = 'sox'
5252

53-
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
54-
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
5553
kaldi_output_dir = common_utils.get_asset_path('kaldi')
54+
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
5655
test_filepaths = {prefix: [] for prefix in compliance_utils.TEST_PREFIX}
5756

57+
def setUp(self):
58+
super().setUp()
59+
60+
# 1. test signal for testing resampling
61+
self.test1_signal_sr = 16000
62+
self.test1_signal = common_utils.get_whitenoise(
63+
sample_rate=self.test1_signal_sr, duration=0.5,
64+
)
65+
66+
# 2. test audio file corresponding to saved kaldi ark files
67+
self.test2_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
68+
5869
# separating test files by their types (e.g 'spec', 'fbank', etc.)
5970
for f in os.listdir(kaldi_output_dir):
6071
dash_idx = f.find('-')
@@ -94,7 +105,6 @@ def test_get_strided(self):
94105

95106
def _create_data_set(self):
96107
# used to generate the dataset to test on. this is not used in testing (offline procedure)
97-
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
98108
sr = 16000
99109
x = torch.arange(0, 20).float()
100110
# between [-6,6]
@@ -103,8 +113,8 @@ def _create_data_set(self):
103113
y = (y / 6 * (1 << 30)).long()
104114
# clear the last 16 bits because they aren't used anyways
105115
y = ((y >> 16) << 16).float()
106-
torchaudio.save(test_filepath, y, sr)
107-
sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
116+
torchaudio.save(self.test_filepath, y, sr)
117+
sound, sample_rate = torchaudio.load(self.test_filepath, normalization=False)
108118
print(y >> 16)
109119
self.assertTrue(sample_rate == sr)
110120
torch.testing.assert_allclose(y, sound)
@@ -123,7 +133,7 @@ def _print_diagnostic(self, output, expect_output):
123133
print('relative_mse:', relative_mse.item(), 'relative_max_error:', relative_max_error.item())
124134

125135
def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
126-
expected_num_args, get_output_fn, atol=1e-5, rtol=1e-8):
136+
expected_num_args, get_output_fn, atol=1e-5, rtol=1e-7):
127137
"""
128138
Inputs:
129139
sound_filepath (str): The location of the sound file
@@ -135,7 +145,7 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
135145
atol (float): absolute tolerance
136146
rtol (float): relative tolerance
137147
"""
138-
sound, sample_rate = torchaudio.load_wav(sound_filepath)
148+
sound, sr = torchaudio.load_wav(sound_filepath)
139149
files = self.test_filepaths[filepath_key]
140150

141151
assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))
@@ -170,22 +180,19 @@ def get_output_fn(sound, args):
170180
output = kaldi.resample_waveform(sound, args[1], args[2])
171181
return output
172182

173-
self._compliance_test_helper(self.test_8000_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
183+
self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
174184

175185
def test_resample_waveform_upsample_size(self):
176-
sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
177-
upsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate * 2)
178-
self.assertTrue(upsample_sound.size(-1) == sound.size(-1) * 2)
186+
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
187+
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)
179188

180189
def test_resample_waveform_downsample_size(self):
181-
sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
182-
downsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate // 2)
183-
self.assertTrue(downsample_sound.size(-1) == sound.size(-1) // 2)
190+
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
191+
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)
184192

185193
def test_resample_waveform_identity_size(self):
186-
sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
187-
downsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate)
188-
self.assertTrue(downsample_sound.size(-1) == sound.size(-1))
194+
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
195+
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))
189196

190197
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
191198
atol=1e-1, rtol=1e-4):
@@ -226,19 +233,19 @@ def test_resample_waveform_upsample_accuracy(self):
226233
def test_resample_waveform_multi_channel(self):
227234
num_channels = 3
228235

229-
sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath) # (1, 8000)
230-
multi_sound = sound.repeat(num_channels, 1) # (num_channels, 8000)
236+
multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
231237

232238
for i in range(num_channels):
233239
multi_sound[i, :] *= (i + 1) * 1.5
234240

235-
multi_sound_sampled = kaldi.resample_waveform(multi_sound, sample_rate, sample_rate // 2)
241+
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)
236242

237243
# check that sampling is same whether using separately or in a tensor of size (c, n)
238244
for i in range(num_channels):
239-
single_channel = sound * (i + 1) * 1.5
240-
single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2)
241-
torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-8)
245+
single_channel = self.test1_signal * (i + 1) * 1.5
246+
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
247+
self.test1_signal_sr // 2)
248+
torch.testing.assert_allclose(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
242249

243250

244251
if __name__ == '__main__':

test/test_sox_effects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def test_lowpass_speed(self):
4545
E.append_effect_to_chain("speed", speed)
4646
E.append_effect_to_chain("rate", si.rate)
4747
x, sr = E.sox_build_flow_effects()
48-
# check if effects worked
49-
self.assertEqual(x.size(1), int((si.length / si.channels) / speed))
48+
# check if effects worked, add small tolerance for rounding effects
49+
self.assertEqual(x.size(1), int((si.length / si.channels) / speed), atol=1, rtol=1e-8)
5050

5151
def test_ulaw_and_siginfo(self):
5252
si_out = torchaudio.sox_signalinfo_t()

0 commit comments

Comments
 (0)