Skip to content

Commit a6efd49

Browse files
authored
Add SPHERE format support (#871)
1 parent 2a6b6b5 commit a6efd49

File tree

5 files changed

+93
-1
lines changed

5 files changed

+93
-1
lines changed

test/torchaudio_unittest/sox_io_backend/info_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
108108
assert info.num_frames == sample_rate * duration
109109
assert info.num_channels == num_channels
110110

111+
@parameterized.expand(list(itertools.product(
112+
[8000, 16000],
113+
[1, 2],
114+
)), name_func=name_func)
115+
def test_sphere(self, sample_rate, num_channels):
116+
"""`sox_io_backend.info` can check sph file correctly"""
117+
duration = 1
118+
path = self.get_temp_path('data.sph')
119+
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration)
120+
info = sox_io_backend.info(path)
121+
assert info.sample_rate == sample_rate
122+
assert info.num_frames == sample_rate * duration
123+
assert info.num_channels == num_channels
124+
111125

112126
@skipIfNoExtension
113127
class TestInfoOpus(PytorchTestCase):

test/torchaudio_unittest/sox_io_backend/load_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,28 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
120120
assert sr == sample_rate
121121
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
122122

123+
def assert_sphere(self, sample_rate, num_channels, duration):
124+
"""`sox_io_backend.load` can load sph format.
125+
126+
This test takes the same strategy as mp3 to compare the result
127+
"""
128+
path = self.get_temp_path('1.original.sph')
129+
ref_path = self.get_temp_path('2.reference.wav')
130+
131+
# 1. Generate sph with sox
132+
sox_utils.gen_audio_file(
133+
path, sample_rate, num_channels,
134+
bit_depth=32, duration=duration)
135+
# 2. Convert to wav with sox
136+
sox_utils.convert_audio_file(path, ref_path)
137+
# 3. Load sph with torchaudio
138+
data, sr = sox_io_backend.load(path)
139+
# 4. Load wav with scipy
140+
data_ref = load_wav(ref_path)[0]
141+
# 5. Compare
142+
assert sr == sample_rate
143+
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
144+
123145

124146
@skipIfNoExec('sox')
125147
@skipIfNoExtension
@@ -230,6 +252,14 @@ def test_opus(self, bitrate, num_channels, compression_level):
230252
assert sample_rate == sr
231253
self.assertEqual(expected, found)
232254

255+
@parameterized.expand(list(itertools.product(
256+
[8000, 16000],
257+
[1, 2],
258+
)), name_func=name_func)
259+
def test_sphere(self, sample_rate, num_channels):
260+
"""`sox_io_backend.load` can load sph format correctly."""
261+
self.assert_sphere(sample_rate, num_channels, duration=1)
262+
233263

234264
@skipIfNoExec('sox')
235265
@skipIfNoExtension

test/torchaudio_unittest/sox_io_backend/save_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,38 @@ def assert_vorbis(self, *args, **kwargs):
168168
else:
169169
raise error
170170

171+
def assert_sphere(self, sample_rate, num_channels, duration):
172+
"""`sox_io_backend.save` can save sph format.
173+
174+
This test takes the same strategy as mp3 to compare the result
175+
"""
176+
src_path = self.get_temp_path('1.reference.wav')
177+
flc_path = self.get_temp_path('2.1.torchaudio.sph')
178+
wav_path = self.get_temp_path('2.2.torchaudio.wav')
179+
flc_path_sox = self.get_temp_path('3.1.sox.sph')
180+
wav_path_sox = self.get_temp_path('3.2.sox.wav')
181+
182+
# 1. Generate original wav
183+
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
184+
save_wav(src_path, data, sample_rate)
185+
# 2.1. Convert the original wav to sph with torchaudio
186+
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate)
187+
# 2.2. Convert the sph to wav with Sox
188+
# converting to 32 bit explicitly; torchaudio saves sph with 32 bit precision (not 24 bit), so the wav stays loadable by scipy.
189+
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
190+
# 2.3. Load
191+
found = load_wav(wav_path)[0]
192+
193+
# 3.1. Convert the original wav to sph with SoX
194+
sox_utils.convert_audio_file(src_path, flc_path_sox)
195+
# 3.2. Convert the sph to wav with Sox
196+
# converting to 32 bit explicitly so that the wav matches the one from step 2.2 and scipy can load it.
197+
sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
198+
# 3.3. Load
199+
expected = load_wav(wav_path_sox)[0]
200+
201+
self.assertEqual(found, expected)
202+
171203

172204
@skipIfNoExec('sox')
173205
@skipIfNoExtension
@@ -262,6 +294,14 @@ def test_vorbis_large(self, sample_rate, num_channels, quality_level):
262294
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
263295
'''
264296

297+
@parameterized.expand(list(itertools.product(
298+
[8000, 16000],
299+
[1, 2],
300+
)), name_func=name_func)
301+
def test_sphere(self, sample_rate, num_channels):
302+
"""`sox_io_backend.save` can save sph format."""
303+
self.assert_sphere(sample_rate, num_channels, duration=1)
304+
265305

266306
@skipIfNoExec('sox')
267307
@skipIfNoExtension

torchaudio/backend/sox_io_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def load(
5858
* FLAC
5959
* OGG/VORBIS
6060
* OPUS
61+
* SPHERE
6162
6263
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
6364
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
@@ -132,6 +133,7 @@ def save(
132133
* MP3
133134
* FLAC
134135
* OGG/VORBIS
136+
* SPHERE
135137
136138
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
137139
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
@@ -158,7 +160,7 @@ def save(
158160
"""
159161
if compression is None:
160162
ext = str(filepath)[-3:].lower()
161-
if ext == 'wav':
163+
if ext in ['wav', 'sph']:
162164
compression = 0.
163165
elif ext == 'mp3':
164166
compression = -4.5

torchaudio/csrc/sox_utils.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ sox_encoding_t get_encoding(
234234
return SOX_ENCODING_FLOAT;
235235
throw std::runtime_error("Unsupported dtype.");
236236
}
237+
if (filetype == "sph")
238+
return SOX_ENCODING_SIGN2;
237239
throw std::runtime_error("Unsupported file type.");
238240
}
239241

@@ -257,6 +259,8 @@ unsigned get_precision(
257259
return 32;
258260
throw std::runtime_error("Unsupported dtype.");
259261
}
262+
if (filetype == "sph")
263+
return 32;
260264
throw std::runtime_error("Unsupported file type.");
261265
}
262266

@@ -285,6 +289,8 @@ sox_encodinginfo_t get_encodinginfo(
285289
return compression;
286290
if (filetype == "wav")
287291
return 0.;
292+
if (filetype == "sph")
293+
return 0.;
288294
throw std::runtime_error("Unsupported file type.");
289295
}();
290296

0 commit comments

Comments
 (0)