diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index b9fb30e114..8e6cd337cb 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -89,6 +89,8 @@ def _get_extra_objects(): 'libvorbisfile.a', 'libvorbis.a', 'libogg.a', + 'libopencore-amrnb.a', + 'libopencore-amrwb.a', ] for lib in libs: objs.append(str(_TP_INSTALL_DIR / 'lib' / lib)) diff --git a/test/torchaudio_unittest/sox_io_backend/info_test.py b/test/torchaudio_unittest/sox_io_backend/info_test.py index b16203cfd3..9b928f3ae0 100644 --- a/test/torchaudio_unittest/sox_io_backend/info_test.py +++ b/test/torchaudio_unittest/sox_io_backend/info_test.py @@ -122,6 +122,36 @@ def test_sphere(self, sample_rate, num_channels): assert info.num_frames == sample_rate * duration assert info.num_channels == num_channels + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_amb(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` can check amb file correctly""" + duration = 1 + path = self.get_temp_path('data.amb') + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + bit_depth=sox_utils.get_bit_depth(dtype), duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + + def test_amr_nb(self): + """`sox_io_backend.info` can check amr-nb file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.amr-nb') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + @skipIfNoExtension class TestInfoOpus(PytorchTestCase): diff --git a/test/torchaudio_unittest/sox_io_backend/load_test.py b/test/torchaudio_unittest/sox_io_backend/load_test.py index 59bcaea386..156ad9ad3e 100644 --- a/test/torchaudio_unittest/sox_io_backend/load_test.py +++ b/test/torchaudio_unittest/sox_io_backend/load_test.py @@ -142,6 +142,53 @@ def assert_sphere(self, sample_rate, num_channels, duration): assert sr == sample_rate self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration): + """`sox_io_backend.load` can load amb format. + + This test takes the same strategy as mp3 to compare the result + """ + path = self.get_temp_path('1.original.amb') + ref_path = self.get_temp_path('2.reference.wav') + + # 1. Generate amb with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + encoding=sox_utils.get_encoding(dtype), + bit_depth=sox_utils.get_bit_depth(dtype), duration=duration) + # 2. Convert to wav with sox + sox_utils.convert_audio_file(path, ref_path) + # 3. Load amb with torchaudio + data, sr = sox_io_backend.load(path, normalize=normalize) + # 4. Load wav with scipy + data_ref = load_wav(ref_path, normalize=normalize)[0] + # 5. Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + + def assert_amr_nb(self, duration): + """`sox_io_backend.load` can load amr-nb format. + + This test takes the same strategy as mp3 to compare the result + """ + sample_rate = 8000 + num_channels = 1 + path = self.get_temp_path('1.original.amr-nb') + ref_path = self.get_temp_path('2.reference.wav') + + # 1. Generate amr-nb with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + bit_depth=32, duration=duration) + # 2. Convert to wav with sox + sox_utils.convert_audio_file(path, ref_path) + # 3. Load amr-nb with torchaudio + data, sr = sox_io_backend.load(path) + # 4. Load wav with scipy + data_ref = load_wav(ref_path)[0] + # 5. Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) + @skipIfNoExec('sox') @skipIfNoExtension @@ -260,6 +307,20 @@ def test_sphere(self, sample_rate, num_channels): """`sox_io_backend.load` can load sph format correctly.""" self.assert_sphere(sample_rate, num_channels, duration=1) + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16'], + [8000, 16000], + [1, 2], + [False, True], + )), name_func=name_func) + def test_amb(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load sph format correctly.""" + self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1) + + def test_amr_nb(self): + """`sox_io_backend.load` can load amr_nb format correctly.""" + self.assert_amr_nb(duration=1) + @skipIfNoExec('sox') @skipIfNoExtension diff --git a/test/torchaudio_unittest/sox_io_backend/save_test.py b/test/torchaudio_unittest/sox_io_backend/save_test.py index deed1aa1ad..b0ee25e01c 100644 --- a/test/torchaudio_unittest/sox_io_backend/save_test.py +++ b/test/torchaudio_unittest/sox_io_backend/save_test.py @@ -200,6 +200,68 @@ def assert_sphere(self, sample_rate, num_channels, duration): self.assertEqual(found, expected) + def assert_amb(self, dtype, sample_rate, num_channels, duration): + """`sox_io_backend.save` can save amb format. + + This test takes the same strategy as mp3 to compare the result + """ + src_path = self.get_temp_path('1.reference.wav') + amb_path = self.get_temp_path('2.1.torchaudio.amb') + wav_path = self.get_temp_path('2.2.torchaudio.wav') + amb_path_sox = self.get_temp_path('3.1.sox.amb') + wav_path_sox = self.get_temp_path('3.2.sox.wav') + + # 1. Generate original wav + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(src_path, data, sample_rate) + # 2.1. Convert the original wav to amb with torchaudio + sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate) + # 2.2. Convert the amb to wav with Sox + sox_utils.convert_audio_file(amb_path, wav_path) + # 2.3. Load + found = load_wav(wav_path)[0] + + # 3.1. Convert the original wav to amb with SoX + sox_utils.convert_audio_file(src_path, amb_path_sox) + # 3.2. Convert the amb to wav with Sox + sox_utils.convert_audio_file(amb_path_sox, wav_path_sox) + # 3.3. Load + expected = load_wav(wav_path_sox)[0] + + self.assertEqual(found, expected) + + def assert_amr_nb(self, duration): + """`sox_io_backend.save` can save amr_nb format. + + This test takes the same strategy as mp3 to compare the result + """ + sample_rate = 8000 + num_channels = 1 + src_path = self.get_temp_path('1.reference.wav') + amr_path = self.get_temp_path('2.1.torchaudio.amr-nb') + wav_path = self.get_temp_path('2.2.torchaudio.wav') + amr_path_sox = self.get_temp_path('3.1.sox.amr-nb') + wav_path_sox = self.get_temp_path('3.2.sox.wav') + + # 1. Generate original wav + data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(src_path, data, sample_rate) + # 2.1. Convert the original wav to amr_nb with torchaudio + sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate) + # 2.2. Convert the amr_nb to wav with Sox + sox_utils.convert_audio_file(amr_path, wav_path) + # 2.3. Load + found = load_wav(wav_path)[0] + + # 3.1. Convert the original wav to amr_nb with SoX + sox_utils.convert_audio_file(src_path, amr_path_sox) + # 3.2. Convert the amr_nb to wav with Sox + sox_utils.convert_audio_file(amr_path_sox, wav_path_sox) + # 3.3. Load + expected = load_wav(wav_path_sox)[0] + + self.assertEqual(found, expected) + @skipIfNoExec('sox') @skipIfNoExtension @@ -302,6 +364,19 @@ def test_sphere(self, sample_rate, num_channels): """`sox_io_backend.save` can save sph format.""" self.assert_sphere(sample_rate, num_channels, duration=1) + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_amb(self, dtype, sample_rate, num_channels): + """`sox_io_backend.save` can save amb format.""" + self.assert_amb(dtype, sample_rate, num_channels, duration=1) + + def test_amr_nb(self): + """`sox_io_backend.save` can save amr-nb format.""" + self.assert_amr_nb(duration=1) + @skipIfNoExec('sox') @skipIfNoExtension diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index da068be8f3..ad0eac582b 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -16,6 +16,14 @@ ExternalProject_Add(libmad CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/libmad/configure ${COMMON_ARGS} ) +ExternalProject_Add(amr + PREFIX ${CMAKE_CURRENT_SOURCE_DIR} + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz + URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341 + CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/amr/configure ${COMMON_ARGS} +) + ExternalProject_Add(libmp3lame PREFIX ${CMAKE_CURRENT_SOURCE_DIR} DOWNLOAD_DIR ${ARCHIVE_DIR} @@ -72,11 +80,11 @@ ExternalProject_Add(opusfile ExternalProject_Add(libsox PREFIX ${CMAKE_CURRENT_SOURCE_DIR} - DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad + DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad amr DOWNLOAD_DIR ${ARCHIVE_DIR} URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2 URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c # OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses. # See https://github.com/pytorch/audio/pull/1026 - CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --disable-openmp + CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp ) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 1e57e92834..5293223977 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -40,18 +40,19 @@ def load( This function can handle all the codecs that underlying libsox can handle, however it is tested on the following formats; - * WAV + * WAV, AMB * 32-bit floating-point * 32-bit signed integer * 16-bit signed integer - * 8-bit unsigned integer + * 8-bit unsigned integer (WAV only) * MP3 * FLAC * OGG/VORBIS * OPUS * SPHERE + * AMR-NB To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` @@ -119,7 +120,7 @@ def save( Note: Supported formats are; - * WAV + * WAV, AMB * 32-bit floating-point * 32-bit signed integer @@ -130,6 +131,7 @@ def save( * FLAC * OGG/VORBIS * SPHERE + * AMR-NB To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` @@ -160,7 +162,7 @@ def save( filepath = str(filepath) if compression is None: ext = str(filepath).split('.')[-1].lower() - if ext in ['wav', 'sph']: + if ext in ['wav', 'sph', 'amb', 'amr-nb']: compression = 0. elif ext == 'mp3': compression = -4.5 diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index d0cc2b4f4e..9acab1c030 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -85,11 +85,17 @@ void save_audio_file( const std::string& file_name, const c10::intrusive_ptr& signal, const double compression) { - const auto tensor = signal->getTensor(); + auto tensor = signal->tensor; validate_input_tensor(tensor); const auto filetype = get_filetype(file_name); + if (filetype == "amr-nb") { + const auto num_channels = tensor.size(signal->channels_first ? 0 : 1); + TORCH_CHECK( + num_channels == 1, "amr-nb format only supports single channel audio."); + tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); + } const auto signal_info = get_signalinfo(signal.get(), filetype); const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression); diff --git a/torchaudio/csrc/sox_utils.cpp b/torchaudio/csrc/sox_utils.cpp index 25fe376c8b..656cc63348 100644 --- a/torchaudio/csrc/sox_utils.cpp +++ b/torchaudio/csrc/sox_utils.cpp @@ -223,7 +223,7 @@ sox_encoding_t get_encoding( return SOX_ENCODING_FLAC; if (filetype == "ogg" || filetype == "vorbis") return SOX_ENCODING_VORBIS; - if (filetype == "wav") { + if (filetype == "wav" || filetype == "amb") { if (dtype == torch::kUInt8) return SOX_ENCODING_UNSIGNED; if (dtype == torch::kInt16) @@ -236,7 +236,9 @@ sox_encoding_t get_encoding( } if (filetype == "sph") return SOX_ENCODING_SIGN2; - throw std::runtime_error("Unsupported file type."); + if (filetype == "amr-nb") + return SOX_ENCODING_AMR_NB; + throw std::runtime_error("Unsupported file type: " + filetype); } unsigned get_precision( @@ -248,7 +250,7 @@ unsigned get_precision( return 24; if (filetype == "ogg" || filetype == "vorbis") return SOX_UNSPEC; - if (filetype == "wav") { + if (filetype == "wav" || filetype == "amb") { if (dtype == torch::kUInt8) return 8; if (dtype == torch::kInt16) @@ -261,7 +263,13 @@ unsigned get_precision( } if (filetype == "sph") return 32; - throw std::runtime_error("Unsupported file type."); + if (filetype == "amr-nb") { + TORCH_INTERNAL_ASSERT( + dtype == torch::kInt16, + "When saving to AMR-NB format, the input tensor must be int16 type."); + return 16; + } + throw std::runtime_error("Unsupported file type: " + filetype); } sox_signalinfo_t get_signalinfo( @@ -287,11 +295,13 @@ sox_encodinginfo_t get_encodinginfo( return compression; if (filetype == "ogg" || filetype == "vorbis") return compression; - if (filetype == "wav") + if (filetype == "wav" || filetype == "amb") return 0.; if (filetype == "sph") return 0.; - throw std::runtime_error("Unsupported file type."); + if (filetype == "amr-nb") + return 0.; + throw std::runtime_error("Unsupported file type: " + filetype); }(); return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),