Skip to content

Commit e88aba5

Browse files
committed
Add save function
1 parent 793eeab commit e88aba5

File tree

10 files changed

+713
-8
lines changed

10 files changed

+713
-8
lines changed

test/sox_io_backend/sox_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def gen_audio_file(
3232
command = [
3333
'sox',
3434
'-V', # verbose
35+
]
36+
if bit_depth is not None:
37+
command += ['--bits', str(bit_depth)]
38+
command += [
3539
'--rate', str(sample_rate),
3640
'--null', # no input
3741
'--channels', str(num_channels),
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import itertools
2+
3+
from torchaudio.backend import sox_io_backend
4+
from parameterized import parameterized
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
PytorchTestCase,
9+
skipIfNoExec,
10+
skipIfNoExtension,
11+
)
12+
from .common import (
13+
get_test_name,
14+
get_wav_data,
15+
)
16+
17+
18+
@skipIfNoExec('sox')
19+
@skipIfNoExtension
20+
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
21+
"""save/load round trip should not degrade data for lossless formats"""
22+
@parameterized.expand(list(itertools.product(
23+
['float32', 'int32', 'int16', 'uint8'],
24+
[8000, 16000],
25+
[1, 2],
26+
[False, True]
27+
)), name_func=get_test_name)
28+
def test_roundtrip_wav(self, dtype, sample_rate, num_channels, normalize):
29+
"""save/load round trip should not degrade data for wav formats"""
30+
original = get_wav_data(dtype, num_channels, normalize=normalize)
31+
data = original
32+
for i in range(10):
33+
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}_{i}.wav')
34+
sox_io_backend.save(path, data, sample_rate)
35+
data, sr = sox_io_backend.load(path, normalize=normalize)
36+
assert sr == sample_rate
37+
self.assertEqual(original, data)
38+
39+
@parameterized.expand(list(itertools.product(
40+
[8000, 16000],
41+
[1, 2],
42+
list(range(9)),
43+
)), name_func=get_test_name)
44+
def test_roundtrip_flac(self, sample_rate, num_channels, compression_level):
45+
"""save/load round trip should not degrade data for flac formats"""
46+
original = get_wav_data('float32', num_channels)
47+
data = original
48+
for i in range(10):
49+
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{i}.flac')
50+
sox_io_backend.save(path, data, sample_rate, compression=compression_level)
51+
data, sr = sox_io_backend.load(path)
52+
assert sr == sample_rate
53+
self.assertEqual(original, data)

test/sox_io_backend/test_save.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
import itertools
2+
3+
from torchaudio.backend import sox_io_backend
4+
from parameterized import parameterized
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
PytorchTestCase,
9+
skipIfNoExec,
10+
skipIfNoExtension,
11+
)
12+
from .common import (
13+
get_test_name,
14+
get_wav_data,
15+
load_wav,
16+
save_wav,
17+
)
18+
from . import sox_utils
19+
20+
21+
class SaveTestBase(TempDirMixin, PytorchTestCase):
22+
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
23+
"""`sox_io_backend.save` can save wav format."""
24+
path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}.wav')
25+
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
26+
sox_io_backend.save(path, expected, sample_rate)
27+
found = load_wav(path)[0]
28+
self.assertEqual(found, expected)
29+
30+
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
31+
"""`sox_io_backend.save` can save mp3 format.
32+
33+
mp3 encoding introduces delay and boundary effects so
34+
we convert the resulting mp3 to wav and compare the results there
35+
36+
|
37+
| 1. Generate original wav file with SciPy
38+
|
39+
v
40+
-------------- wav ----------------
41+
| |
42+
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox
43+
| then save with torchaudio |
44+
v v
45+
mp3 mp3
46+
| |
47+
| 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox
48+
| |
49+
v v
50+
wav wav
51+
| |
52+
| 2.3. load with scipy | 3.3. load with scipy
53+
| |
54+
v v
55+
tensor -------> compare <--------- tensor
56+
57+
"""
58+
src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav')
59+
mp3_path = f'{src_path}.mp3'
60+
wav_path = f'{mp3_path}.wav'
61+
mp3_path_sox = f'{src_path}.sox.mp3'
62+
wav_path_sox = f'{mp3_path_sox}.wav'
63+
64+
# 1. Generate original wav
65+
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
66+
save_wav(src_path, data, sample_rate)
67+
# 2.1. Convert the original wav to mp3 with torchaudio
68+
sox_io_backend.save(
69+
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
70+
# 2.2. Convert the mp3 to wav with Sox
71+
sox_utils.convert_audio_file(mp3_path, wav_path)
72+
# 2.3. Load
73+
found = load_wav(wav_path)[0]
74+
75+
# 3.1. Convert the original wav to mp3 with SoX
76+
sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate)
77+
# 3.2. Convert the mp3 to wav with Sox
78+
sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox)
79+
# 3.3. Load
80+
expected = load_wav(wav_path_sox)[0]
81+
82+
self.assertEqual(found, expected)
83+
84+
def assert_flac(self, sample_rate, num_channels, compression_level, duration):
85+
"""`sox_io_backend.save` can save flac format.
86+
87+
This test takes the same strategy as mp3 to compare the result
88+
"""
89+
src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav')
90+
flac_path = f'{src_path}.flac'
91+
wav_path = f'{flac_path}.wav'
92+
flac_path_sox = f'{src_path}.sox.flac'
93+
wav_path_sox = f'{flac_path_sox}.wav'
94+
95+
# 1. Generate original wav
96+
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
97+
save_wav(src_path, data, sample_rate)
98+
# 2.1. Convert the original wav to flac with torchaudio
99+
sox_io_backend.save(
100+
flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
101+
# 2.2. Convert the flac to wav with Sox
102+
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
103+
sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32)
104+
# 2.3. Load
105+
found = load_wav(wav_path)[0]
106+
107+
# 3.1. Convert the original wav to flac with SoX
108+
sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level)
109+
# 3.2. Convert the flac to wav with Sox
110+
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
111+
sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32)
112+
# 3.3. Load
113+
expected = load_wav(wav_path_sox)[0]
114+
115+
self.assertEqual(found, expected)
116+
117+
def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
118+
"""`sox_io_backend.save` can save vorbis format.
119+
120+
This test takes the same strategy as mp3 to compare the result
121+
"""
122+
src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav')
123+
vorbis_path = f'{src_path}.vorbis'
124+
wav_path = f'{vorbis_path}.wav'
125+
vorbis_path_sox = f'{src_path}.sox.vorbis'
126+
wav_path_sox = f'{vorbis_path_sox}.wav'
127+
128+
# 1. Generate original wav
129+
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
130+
save_wav(src_path, data, sample_rate)
131+
# 2.1. Convert the original wav to vorbis with torchaudio
132+
sox_io_backend.save(
133+
vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
134+
# 2.2. Convert the vorbis to wav with Sox
135+
sox_utils.convert_audio_file(vorbis_path, wav_path)
136+
# 2.3. Load
137+
found = load_wav(wav_path)[0]
138+
139+
# 3.1. Convert the original wav to vorbis with SoX
140+
sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level)
141+
# 3.2. Convert the vorbis to wav with Sox
142+
sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox)
143+
# 3.3. Load
144+
expected = load_wav(wav_path_sox)[0]
145+
146+
# sox's vorbis encoding has some random boundary effect, which cause small number of
147+
# samples yields higher descrepency than the others.
148+
# so we allow small portions of data to be outside of absolute torelance.
149+
# make sure to pass somewhat long duration
150+
atol = 1.0e-4
151+
max_failure_allowed = 0.01 # this percent of samples are allowed to outside of atol.
152+
failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel()
153+
if failure_ratio > max_failure_allowed:
154+
# it's failed and this will give a better error message.
155+
self.assertEqual(found, expected, atol=atol, rtol=1.3e-6)
156+
157+
def assert_vorbis(self, *args, **kwargs):
158+
# sox's vorbis encoding has some randomness, so we run tests multiple time
159+
max_retry = 5
160+
error = None
161+
for _ in range(max_retry):
162+
try:
163+
self._assert_vorbis(*args, **kwargs)
164+
break
165+
except AssertionError as e:
166+
error = e
167+
else:
168+
raise error
169+
170+
171+
@skipIfNoExec('sox')
172+
@skipIfNoExtension
173+
class TestSave(SaveTestBase):
174+
@parameterized.expand(list(itertools.product(
175+
['float32', 'int32', 'int16', 'uint8'],
176+
[8000, 16000],
177+
[1, 2],
178+
)), name_func=get_test_name)
179+
def test_wav(self, dtype, sample_rate, num_channels):
180+
"""`sox_io_backend.save` can save wav format."""
181+
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
182+
183+
@parameterized.expand(list(itertools.product(
184+
['float32'],
185+
[16000],
186+
[2],
187+
)), name_func=get_test_name)
188+
def test_wav_large(self, dtype, sample_rate, num_channels):
189+
"""`sox_io_backend.save` can save large wav file."""
190+
two_hours = 2 * 60 * 60 * sample_rate
191+
self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours)
192+
193+
@parameterized.expand(list(itertools.product(
194+
['float32', 'int32', 'int16', 'uint8'],
195+
[4, 8, 16, 32],
196+
)), name_func=get_test_name)
197+
def test_multiple_channels(self, dtype, num_channels):
198+
"""`sox_io_backend.save` can save wav with more than 2 channels."""
199+
sample_rate = 8000
200+
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
201+
202+
@parameterized.expand(list(itertools.product(
203+
[8000, 16000],
204+
[1, 2],
205+
[96, 128, 160, 192, 224, 256, 320],
206+
)), name_func=get_test_name)
207+
def test_mp3(self, sample_rate, num_channels, bit_rate):
208+
"""`sox_io_backend.save` can save mp3 format."""
209+
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)
210+
211+
@parameterized.expand(list(itertools.product(
212+
[16000],
213+
[2],
214+
[128],
215+
)), name_func=get_test_name)
216+
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
217+
"""`sox_io_backend.save` can save large mp3 file."""
218+
two_hours = 2 * 60 * 60
219+
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours)
220+
221+
@parameterized.expand(list(itertools.product(
222+
[8000, 16000],
223+
[1, 2],
224+
list(range(9)),
225+
)), name_func=get_test_name)
226+
def test_flac(self, sample_rate, num_channels, compression_level):
227+
"""`sox_io_backend.save` can save flac format."""
228+
self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
229+
230+
@parameterized.expand(list(itertools.product(
231+
[16000],
232+
[2],
233+
[0],
234+
)), name_func=get_test_name)
235+
def test_flac_large(self, sample_rate, num_channels, compression_level):
236+
"""`sox_io_backend.save` can save large flac file."""
237+
two_hours = 2 * 60 * 60
238+
self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours)
239+
240+
@parameterized.expand(list(itertools.product(
241+
[8000, 16000],
242+
[1, 2],
243+
[-1, 0, 1, 2, 3, 3.6, 5, 10],
244+
)), name_func=get_test_name)
245+
def test_vorbis(self, sample_rate, num_channels, quality_level):
246+
"""`sox_io_backend.save` can save vorbis format."""
247+
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20)
248+
249+
# note: torchaudio can load large vorbis file, but cannot save large volbis file
250+
# the following test causes Segmentation fault
251+
#
252+
'''
253+
@parameterized.expand(list(itertools.product(
254+
[16000],
255+
[2],
256+
[10],
257+
)), name_func=get_test_name)
258+
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
259+
"""`sox_io_backend.save` can save large vorbis file correctly."""
260+
two_hours = 2 * 60 * 60
261+
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
262+
'''
263+
264+
265+
@skipIfNoExec('sox')
266+
@skipIfNoExtension
267+
class TestSaveParams(TempDirMixin, PytorchTestCase):
268+
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
269+
@parameterized.expand([(True, ), (False, )], name_func=get_test_name)
270+
def test_channels_first(self, channels_first):
271+
"""channels_first swaps axes"""
272+
path = self.get_temp_path('test_channel_first_{channels_first}.wav')
273+
data = get_wav_data('int32', 2, channels_first=channels_first)
274+
sox_io_backend.save(
275+
path, data, 8000, channels_first=channels_first)
276+
found = load_wav(path)[0]
277+
expected = data if channels_first else data.transpose(1, 0)
278+
self.assertEqual(found, expected)
279+
280+
@parameterized.expand([
281+
'float32', 'int32', 'int16', 'uint8'
282+
], name_func=get_test_name)
283+
def test_noncontiguous(self, dtype):
284+
"""Noncontiguous tensors are saved correctly"""
285+
path = self.get_temp_path('test_uncontiguous_{dtype}.wav')
286+
expected = get_wav_data(dtype, 4)[::2, ::2]
287+
assert not expected.is_contiguous()
288+
sox_io_backend.save(path, expected, 8000)
289+
found = load_wav(path)[0]
290+
self.assertEqual(found, expected)
291+
292+
@parameterized.expand([
293+
'float32', 'int32', 'int16', 'uint8',
294+
])
295+
def test_tensor_preserve(self, dtype):
296+
"""save function should not alter Tensor"""
297+
path = self.get_temp_path(f'test_preserve_{dtype}.wav')
298+
expected = get_wav_data(dtype, 4)[::2, ::2]
299+
300+
data = expected.clone()
301+
sox_io_backend.save(path, data, 8000)
302+
303+
self.assertEqual(data, expected)

0 commit comments

Comments
 (0)