Skip to content

Commit e96727e

Browse files
committed
Add save function
1 parent f8ca606 commit e96727e

File tree

8 files changed

+662
-1
lines changed

8 files changed

+662
-1
lines changed

test/sox_io_backend/sox_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def gen_audio_file(
2929
command = [
3030
'sox',
3131
'-V', # verbose
32+
]
33+
if bit_depth is not None:
34+
command += ['--bits', str(bit_depth)]
35+
command += [
3236
'--rate', str(sample_rate),
3337
'--null', # no input
3438
'--channels', str(num_channels),
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import itertools
2+
3+
from torchaudio.backend import sox_io_backend
4+
from parameterized import parameterized
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
PytorchTestCase,
9+
skipIfNoExec,
10+
skipIfNoExtension,
11+
)
12+
from .common import (
13+
get_test_name,
14+
get_wav_data,
15+
)
16+
from . import sox_utils
17+
18+
19+
@skipIfNoExec('sox')
20+
@skipIfNoExtension
21+
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
22+
"""save/load round trip should not degrade data for lossless formats"""
23+
@parameterized.expand(list(itertools.product(
24+
['float32', 'int32', 'int16', 'uint8'],
25+
[8000, 16000],
26+
[1, 2],
27+
[False, True]
28+
)), name_func=get_test_name)
29+
def test_roundtrip_wav(self, dtype, sample_rate, num_channels, normalize):
30+
"""save/load round trip should not degrade data for wav formats"""
31+
original = get_wav_data(dtype, num_channels, normalize=normalize)
32+
data = original
33+
for i in range(10):
34+
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}_{i}.wav')
35+
sox_io_backend.save(path, data, sample_rate)
36+
data, sr = sox_io_backend.load(path, normalize=normalize)
37+
assert sr == sample_rate
38+
self.assertEqual(original, data)
39+
40+
@parameterized.expand(list(itertools.product(
41+
[8000, 16000],
42+
[1, 2],
43+
list(range(9)),
44+
)), name_func=get_test_name)
45+
def test_roundtrip_flac(self, sample_rate, num_channels, compression_level):
46+
"""save/load round trip should not degrade data for flac formats"""
47+
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
48+
sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=compression_level)
49+
original = sox_io_backend.load(path)[0]
50+
51+
data = original
52+
for i in range(10):
53+
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{i}.flac')
54+
sox_io_backend.save(path, data, sample_rate)
55+
data, sr = sox_io_backend.load(path)
56+
assert sr == sample_rate
57+
self.assertEqual(original, data)

test/sox_io_backend/test_save.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
import itertools
2+
3+
from torchaudio.backend import sox_io_backend
4+
from parameterized import parameterized
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
PytorchTestCase,
9+
skipIfNoExec,
10+
skipIfNoExtension,
11+
)
12+
from .common import (
13+
get_test_name,
14+
get_wav_data,
15+
load_wav,
16+
)
17+
from . import sox_utils
18+
19+
20+
class SaveTestBase(TempDirMixin, PytorchTestCase):
21+
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
22+
"""`sox_io_backend.save` can save wav format."""
23+
path = self.get_temp_path(f'test_wav_{dtype}_{sample_rate}_{num_channels}.wav')
24+
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
25+
sox_io_backend.save(path, expected, sample_rate)
26+
found = load_wav(path)[0]
27+
self.assertEqual(found, expected)
28+
29+
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
30+
"""`sox_io_backend.save` can save mp3 format.
31+
32+
mp3 encoding introduces delay and boundary effects so
33+
we convert the resulting mp3 to wav and compare the results there
34+
35+
|
36+
| 1. Generate original wav with Sox
37+
|
38+
v
39+
-------------- wav ----------------
40+
| |
41+
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox
42+
| then save with torchaudio |
43+
v v
44+
mp3 mp3
45+
| |
46+
| 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox
47+
| |
48+
v v
49+
wav wav
50+
| |
51+
| 2.3. load with scipy | 3.3. load with scipy
52+
| |
53+
v v
54+
tensor -------> compare <--------- tensor
55+
56+
"""
57+
src_path = self.get_temp_path(f'test_mp3_{sample_rate}_{num_channels}_{bit_rate}_{duration}.wav')
58+
mp3_path = f'{src_path}.mp3'
59+
wav_path = f'{mp3_path}.wav'
60+
mp3_path_sox = f'{src_path}.sox.mp3'
61+
wav_path_sox = f'{mp3_path_sox}.wav'
62+
63+
# 1. Generate original wav
64+
sox_utils.gen_audio_file(
65+
src_path, sample_rate, num_channels,
66+
bit_depth=32,
67+
encoding='floating-point',
68+
duration=duration,
69+
)
70+
# 2.1. Convert the original wav to mp3 with torchaudio
71+
sox_io_backend.save(
72+
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
73+
# 2.2. Convert the mp3 to wav with Sox
74+
sox_utils.convert_audio_file(mp3_path, wav_path)
75+
# 2.3. Load
76+
found = load_wav(wav_path)[0]
77+
78+
# 3.1. Convert the original wav to mp3 with SoX
79+
sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate)
80+
# 3.2. Convert the mp3 to wav with Sox
81+
sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox)
82+
# 3.3. Load
83+
expected = load_wav(wav_path_sox)[0]
84+
85+
self.assertEqual(found, expected)
86+
87+
def assert_flac(self, sample_rate, num_channels, compression_level, duration):
88+
"""`sox_io_backend.save` can save flac format.
89+
90+
This test takes the same strategy as mp3 to compare the result
91+
"""
92+
src_path = self.get_temp_path(f'test_flac_{sample_rate}_{num_channels}_{compression_level}_{duration}.wav')
93+
flac_path = f'{src_path}.flac'
94+
wav_path = f'{flac_path}.wav'
95+
flac_path_sox = f'{src_path}.sox.flac'
96+
wav_path_sox = f'{flac_path_sox}.wav'
97+
98+
# 1. Generate original wav
99+
sox_utils.gen_audio_file(
100+
src_path, sample_rate, num_channels,
101+
bit_depth=32,
102+
encoding='floating-point',
103+
duration=duration,
104+
)
105+
# 2.1. Convert the original wav to flac with torchaudio
106+
sox_io_backend.save(
107+
flac_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
108+
# 2.2. Convert the flac to wav with Sox
109+
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
110+
sox_utils.convert_audio_file(flac_path, wav_path, bit_depth=32)
111+
# 2.3. Load
112+
found = load_wav(wav_path)[0]
113+
114+
# 3.1. Convert the original wav to flac with SoX
115+
sox_utils.convert_audio_file(src_path, flac_path_sox, compression=compression_level)
116+
# 3.2. Convert the flac to wav with Sox
117+
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
118+
sox_utils.convert_audio_file(flac_path_sox, wav_path_sox, bit_depth=32)
119+
# 3.3. Load
120+
expected = load_wav(wav_path_sox)[0]
121+
122+
self.assertEqual(found, expected)
123+
124+
def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
125+
"""`sox_io_backend.save` can save vorbis format.
126+
127+
This test takes the same strategy as mp3 to compare the result
128+
"""
129+
src_path = self.get_temp_path(f'test_vorbis_{sample_rate}_{num_channels}_{quality_level}_{duration}.wav')
130+
vorbis_path = f'{src_path}.vorbis'
131+
wav_path = f'{vorbis_path}.wav'
132+
vorbis_path_sox = f'{src_path}.sox.vorbis'
133+
wav_path_sox = f'{vorbis_path_sox}.wav'
134+
135+
# 1. Generate original wav
136+
sox_utils.gen_audio_file(
137+
src_path, sample_rate, num_channels,
138+
bit_depth=16,
139+
encoding='signed-integer',
140+
duration=duration,
141+
)
142+
# 2.1. Convert the original wav to vorbis with torchaudio
143+
sox_io_backend.save(
144+
vorbis_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
145+
# 2.2. Convert the vorbis to wav with Sox
146+
sox_utils.convert_audio_file(vorbis_path, wav_path)
147+
# 2.3. Load
148+
found = load_wav(wav_path)[0]
149+
150+
# 3.1. Convert the original wav to vorbis with SoX
151+
sox_utils.convert_audio_file(src_path, vorbis_path_sox, compression=quality_level)
152+
# 3.2. Convert the vorbis to wav with Sox
153+
sox_utils.convert_audio_file(vorbis_path_sox, wav_path_sox)
154+
# 3.3. Load
155+
expected = load_wav(wav_path_sox)[0]
156+
157+
# vorbis encoding seems to have some randomness,
158+
# so setting absolute torelance a bit higher
159+
self.assertEqual(found, expected, rtol=1.3e-06, atol=1e-03)
160+
161+
162+
@skipIfNoExec('sox')
163+
@skipIfNoExtension
164+
class TestSave(SaveTestBase):
165+
@parameterized.expand(list(itertools.product(
166+
['float32', 'int32', 'int16', 'uint8'],
167+
[8000, 16000],
168+
[1, 2],
169+
)), name_func=get_test_name)
170+
def test_wav(self, dtype, sample_rate, num_channels):
171+
"""`sox_io_backend.save` can save wav format."""
172+
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
173+
174+
@parameterized.expand(list(itertools.product(
175+
['float32'],
176+
[16000],
177+
[2],
178+
)), name_func=get_test_name)
179+
def test_wav_large(self, dtype, sample_rate, num_channels):
180+
"""`sox_io_backend.save` can save large wav file."""
181+
two_hours = 2 * 60 * 60 * sample_rate
182+
self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours)
183+
184+
@parameterized.expand(list(itertools.product(
185+
['float32', 'int32', 'int16', 'uint8'],
186+
[4, 8, 16, 32],
187+
)), name_func=get_test_name)
188+
def test_multiple_channels(self, dtype, num_channels):
189+
"""`sox_io_backend.save` can save wav with more than 2 channels."""
190+
sample_rate = 8000
191+
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
192+
193+
@parameterized.expand(list(itertools.product(
194+
[8000, 16000],
195+
[1, 2],
196+
[96, 128, 160, 192, 224, 256, 320],
197+
)), name_func=get_test_name)
198+
def test_mp3(self, sample_rate, num_channels, bit_rate):
199+
"""`sox_io_backend.save` can save mp3 format."""
200+
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)
201+
202+
@parameterized.expand(list(itertools.product(
203+
[16000],
204+
[2],
205+
[128],
206+
)), name_func=get_test_name)
207+
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
208+
"""`sox_io_backend.save` can save large mp3 file."""
209+
two_hours = 2 * 60 * 60
210+
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours)
211+
212+
@parameterized.expand(list(itertools.product(
213+
[8000, 16000],
214+
[1, 2],
215+
list(range(9)),
216+
)), name_func=get_test_name)
217+
def test_flac(self, sample_rate, num_channels, compression_level):
218+
"""`sox_io_backend.save` can save flac format."""
219+
self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
220+
221+
@parameterized.expand(list(itertools.product(
222+
[16000],
223+
[2],
224+
[0],
225+
)), name_func=get_test_name)
226+
def test_flac_large(self, sample_rate, num_channels, compression_level):
227+
"""`sox_io_backend.save` can save large flac file."""
228+
two_hours = 2 * 60 * 60
229+
self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours)
230+
231+
@parameterized.expand(list(itertools.product(
232+
[8000, 16000],
233+
[-1, 0, 1, 2, 3, 3.6, 5, 10],
234+
)), name_func=get_test_name)
235+
def test_vorbis_mono(self, sample_rate, quality_level):
236+
"""`sox_io_backend.save` can save vorbis format."""
237+
num_channels = 1
238+
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)
239+
240+
@parameterized.expand(list(itertools.product(
241+
[8000],
242+
[-1, 0, 1, 2, 3, 3.6, 5, 10],
243+
)), name_func=get_test_name)
244+
def test_vorbis_stereo(self, sample_rate, quality_level):
245+
"""`sox_io_backend.save` can save vorbis format."""
246+
# note: though sample_rate16000 mostly works fine, it is omitted because
247+
# it gives slighly higher descrepency for few samples.
248+
# such as 762 out of 32000 not within atol=0.0001.
249+
num_channels = 2
250+
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)
251+
252+
# note: torchaudio can load large vorbis file, but cannot save large volbis file
253+
# the following test causes Segmentation fault
254+
#
255+
'''
256+
@parameterized.expand(list(itertools.product(
257+
[16000],
258+
[2],
259+
[10],
260+
)), name_func=get_test_name)
261+
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
262+
"""`sox_io_backend.save` can save large vorbis file correctly."""
263+
two_hours = 2 * 60 * 60
264+
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
265+
'''
266+
267+
268+
@skipIfNoExec('sox')
269+
@skipIfNoExtension
270+
class TestSaveParams(TempDirMixin, PytorchTestCase):
271+
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
272+
@parameterized.expand([(True, ), (False, )], name_func=get_test_name)
273+
def test_channels_first(self, channels_first):
274+
"""channels_first swaps axes"""
275+
path = self.get_temp_path('test_channel_first_{channels_first}.wav')
276+
data = get_wav_data('int32', 2, channels_first=channels_first)
277+
sox_io_backend.save(
278+
path, data, 8000, channels_first=channels_first)
279+
found = load_wav(path)[0]
280+
expected = data if channels_first else data.transpose(1, 0)
281+
self.assertEqual(found, expected)
282+
283+
@parameterized.expand([
284+
'float32', 'int32', 'int16', 'uint8'
285+
], name_func=get_test_name)
286+
def test_noncontiguous(self, dtype):
287+
"""Noncontiguous tensors are saved correctly"""
288+
path = self.get_temp_path('test_uncontiguous_{dtype}.wav')
289+
expected = get_wav_data(dtype, 4)[::2, ::2]
290+
assert not expected.is_contiguous()
291+
sox_io_backend.save(path, expected, 8000)
292+
found = load_wav(path)[0]
293+
self.assertEqual(found, expected)
294+
295+
@parameterized.expand([
296+
'float32', 'int32', 'int16', 'uint8',
297+
])
298+
def test_tensor_preserve(self, dtype):
299+
"""save function should not alter Tensor"""
300+
path = self.get_temp_path(f'test_preserve_{dtype}.wav')
301+
expected = get_wav_data(dtype, 4)[::2, ::2]
302+
303+
data = expected.clone()
304+
sox_io_backend.save(path, data, 8000)
305+
306+
self.assertEqual(data, expected)

0 commit comments

Comments
 (0)