Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.info` can check sph file correctly"""
duration = 1
path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels


@skipIfNoExtension
class TestInfoOpus(PytorchTestCase):
Expand Down
30 changes: 30 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,28 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_sphere(self, sample_rate, num_channels, duration):
"""`sox_io_backend.load` can load sph format.

This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.sph')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate sph with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load sph with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down Expand Up @@ -230,6 +252,14 @@ def test_opus(self, bitrate, num_channels, compression_level):
assert sample_rate == sr
self.assertEqual(expected, found)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_sphere(sample_rate, num_channels, duration=1)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down
40 changes: 40 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,38 @@ def assert_vorbis(self, *args, **kwargs):
else:
raise error

def assert_sphere(self, sample_rate, num_channels, duration):
"""`sox_io_backend.save` can save sph format.

This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
flc_path = self.get_temp_path('2.1.torchaudio.sph')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
flc_path_sox = self.get_temp_path('3.1.sox.sph')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to sph with torchaudio
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to sph with SoX
sox_utils.convert_audio_file(src_path, flc_path_sox)
# 3.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

self.assertEqual(found, expected)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down Expand Up @@ -262,6 +294,14 @@ def test_vorbis_large(self, sample_rate, num_channels, quality_level):
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
'''

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.save` can save sph format."""
self.assert_sphere(sample_rate, num_channels, duration=1)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down
4 changes: 3 additions & 1 deletion torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def load(
* FLAC
* OGG/VORBIS
* OPUS
* SPHERE

To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
Expand Down Expand Up @@ -132,6 +133,7 @@ def save(
* MP3
* FLAC
* OGG/VORBIS
* SPHERE

To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
Expand All @@ -158,7 +160,7 @@ def save(
"""
if compression is None:
ext = str(filepath)[-3:].lower()
if ext == 'wav':
if ext in ['wav', 'sph']:
compression = 0.
elif ext == 'mp3':
compression = -4.5
Expand Down
6 changes: 6 additions & 0 deletions torchaudio/csrc/sox_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ sox_encoding_t get_encoding(
return SOX_ENCODING_FLOAT;
throw std::runtime_error("Unsupported dtype.");
}
if (filetype == "sph")
return SOX_ENCODING_SIGN2;
throw std::runtime_error("Unsupported file type.");
}

Expand All @@ -257,6 +259,8 @@ unsigned get_precision(
return 32;
throw std::runtime_error("Unsupported dtype.");
}
if (filetype == "sph")
return 32;
throw std::runtime_error("Unsupported file type.");
}

Expand Down Expand Up @@ -285,6 +289,8 @@ sox_encodinginfo_t get_encodinginfo(
return compression;
if (filetype == "wav")
return 0.;
if (filetype == "sph")
return 0.;
throw std::runtime_error("Unsupported file type.");
}();

Expand Down