From f78ebf9556cf237de9e2fb8b3b156dca90769435 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Mon, 10 Aug 2020 19:16:34 +0000 Subject: [PATCH] Add SPHERE format support --- .../sox_io_backend/info_test.py | 14 +++++++ .../sox_io_backend/load_test.py | 30 ++++++++++++++ .../sox_io_backend/save_test.py | 40 +++++++++++++++++++ torchaudio/backend/sox_io_backend.py | 4 +- torchaudio/csrc/sox_utils.cpp | 6 +++ 5 files changed, 93 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/sox_io_backend/info_test.py b/test/torchaudio_unittest/sox_io_backend/info_test.py index f07b72e112..b16203cfd3 100644 --- a/test/torchaudio_unittest/sox_io_backend/info_test.py +++ b/test/torchaudio_unittest/sox_io_backend/info_test.py @@ -108,6 +108,20 @@ def test_vorbis(self, sample_rate, num_channels, quality_level): assert info.num_frames == sample_rate * duration assert info.num_channels == num_channels + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_sphere(self, sample_rate, num_channels): + """`sox_io_backend.info` can check sph file correctly""" + duration = 1 + path = self.get_temp_path('data.sph') + sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + @skipIfNoExtension class TestInfoOpus(PytorchTestCase): diff --git a/test/torchaudio_unittest/sox_io_backend/load_test.py b/test/torchaudio_unittest/sox_io_backend/load_test.py index 1ff8b76dc5..59bcaea386 100644 --- a/test/torchaudio_unittest/sox_io_backend/load_test.py +++ b/test/torchaudio_unittest/sox_io_backend/load_test.py @@ -120,6 +120,28 @@ def assert_vorbis(self, sample_rate, num_channels, quality_level, duration): assert sr == sample_rate self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + def assert_sphere(self, sample_rate, num_channels, duration): + """`sox_io_backend.load` can load sph format. + + This test takes the same strategy as mp3 to compare the result + """ + path = self.get_temp_path('1.original.sph') + ref_path = self.get_temp_path('2.reference.wav') + + # 1. Generate sph with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + bit_depth=32, duration=duration) + # 2. Convert to wav with sox + sox_utils.convert_audio_file(path, ref_path) + # 3. Load sph with torchaudio + data, sr = sox_io_backend.load(path) + # 4. Load wav with scipy + data_ref = load_wav(ref_path)[0] + # 5. Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + @skipIfNoExec('sox') @skipIfNoExtension @@ -230,6 +252,14 @@ def test_opus(self, bitrate, num_channels, compression_level): assert sample_rate == sr self.assertEqual(expected, found) + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_sphere(self, sample_rate, num_channels): + """`sox_io_backend.load` can load sph format correctly.""" + self.assert_sphere(sample_rate, num_channels, duration=1) + @skipIfNoExec('sox') @skipIfNoExtension diff --git a/test/torchaudio_unittest/sox_io_backend/save_test.py b/test/torchaudio_unittest/sox_io_backend/save_test.py index 8e416b7aa9..f7d4da3931 100644 --- a/test/torchaudio_unittest/sox_io_backend/save_test.py +++ b/test/torchaudio_unittest/sox_io_backend/save_test.py @@ -168,6 +168,38 @@ def assert_vorbis(self, *args, **kwargs): else: raise error + def assert_sphere(self, sample_rate, num_channels, duration): + """`sox_io_backend.save` can save sph format. + + This test takes the same strategy as mp3 to compare the result + """ + src_path = self.get_temp_path('1.reference.wav') + flc_path = self.get_temp_path('2.1.torchaudio.sph') + wav_path = self.get_temp_path('2.2.torchaudio.wav') + flc_path_sox = self.get_temp_path('3.1.sox.sph') + wav_path_sox = self.get_temp_path('3.2.sox.wav') + + # 1. Generate original wav + data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) + save_wav(src_path, data, sample_rate) + # 2.1. Convert the original wav to sph with torchaudio + sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate) + # 2.2. Convert the sph to wav with Sox + # converting to 32 bit because sph file has 24 bit depth which scipy cannot handle. + sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32) + # 2.3. Load + found = load_wav(wav_path)[0] + + # 3.1. Convert the original wav to sph with SoX + sox_utils.convert_audio_file(src_path, flc_path_sox) + # 3.2. Convert the sph to wav with Sox + # converting to 32 bit because sph file has 24 bit depth which scipy cannot handle. + sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32) + # 3.3. Load + expected = load_wav(wav_path_sox)[0] + + self.assertEqual(found, expected) + @skipIfNoExec('sox') @skipIfNoExtension @@ -262,6 +294,14 @@ def test_vorbis_large(self, sample_rate, num_channels, quality_level): self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) ''' + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_sphere(self, sample_rate, num_channels): + """`sox_io_backend.save` can save sph format.""" + self.assert_sphere(sample_rate, num_channels, duration=1) + @skipIfNoExec('sox') @skipIfNoExtension diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 57d8db7723..9c0695d723 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -58,6 +58,7 @@ def load( * FLAC * OGG/VORBIS * OPUS + * SPHERE To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` @@ -132,6 +133,7 @@ def save( * MP3 * FLAC * OGG/VORBIS + * SPHERE To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` @@ -158,7 +160,7 @@ def save( """ if compression is None: ext = str(filepath)[-3:].lower() - if ext == 'wav': + if ext in ['wav', 'sph']: compression = 0. elif ext == 'mp3': compression = -4.5 diff --git a/torchaudio/csrc/sox_utils.cpp b/torchaudio/csrc/sox_utils.cpp index c6b73c5795..25fe376c8b 100644 --- a/torchaudio/csrc/sox_utils.cpp +++ b/torchaudio/csrc/sox_utils.cpp @@ -234,6 +234,8 @@ sox_encoding_t get_encoding( return SOX_ENCODING_FLOAT; throw std::runtime_error("Unsupported dtype."); } + if (filetype == "sph") + return SOX_ENCODING_SIGN2; throw std::runtime_error("Unsupported file type."); } @@ -257,6 +259,8 @@ unsigned get_precision( return 32; throw std::runtime_error("Unsupported dtype."); } + if (filetype == "sph") + return 32; throw std::runtime_error("Unsupported file type."); } @@ -285,6 +289,8 @@ sox_encodinginfo_t get_encodinginfo( return compression; if (filetype == "wav") return 0.; + if (filetype == "sph") + return 0.; throw std::runtime_error("Unsupported file type."); }();