Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def _get_extra_objects():
'libvorbisfile.a',
'libvorbis.a',
'libogg.a',
'libopencore-amrnb.a',
'libopencore-amrwb.a',
]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
Expand Down
30 changes: 30 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,36 @@ def test_sphere(self, sample_rate, num_channels):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels


@skipIfNoExtension
class TestInfoOpus(PytorchTestCase):
Expand Down
61 changes: 61 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,53 @@ def assert_sphere(self, sample_rate, num_channels, duration):
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load amb format.

This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.amb')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate amb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amb with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_amr_nb(self, duration):
"""`sox_io_backend.load` can load amr-nb format.

This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
path = self.get_temp_path('1.original.amr-nb')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate amr-nb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amr-nb with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down Expand Up @@ -260,6 +307,20 @@ def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_sphere(sample_rate, num_channels, duration=1)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1)

def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_amr_nb(duration=1)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down
75 changes: 75 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,68 @@ def assert_sphere(self, sample_rate, num_channels, duration):

self.assertEqual(found, expected)

def assert_amb(self, dtype, sample_rate, num_channels, duration):
"""`sox_io_backend.save` can save amb format.

This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
amb_path = self.get_temp_path('2.1.torchaudio.amb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amb_path_sox = self.get_temp_path('3.1.sox.amb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amb with torchaudio
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
# 2.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to amb with SoX
sox_utils.convert_audio_file(src_path, amb_path_sox)
# 3.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

self.assertEqual(found, expected)

def assert_amr_nb(self, duration):
"""`sox_io_backend.save` can save amr_nb format.

This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
src_path = self.get_temp_path('1.reference.wav')
amr_path = self.get_temp_path('2.1.torchaudio.amr-nb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amr_path_sox = self.get_temp_path('3.1.sox.amr-nb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to amr_nb with SoX
sox_utils.convert_audio_file(src_path, amr_path_sox)
# 3.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

self.assertEqual(found, expected)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down Expand Up @@ -302,6 +364,19 @@ def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.save` can save sph format."""
self.assert_sphere(sample_rate, num_channels, duration=1)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save amb format."""
self.assert_amb(dtype, sample_rate, num_channels, duration=1)

def test_amr_nb(self):
"""`sox_io_backend.save` can save amr-nb format."""
self.assert_amr_nb(duration=1)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down
12 changes: 10 additions & 2 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ ExternalProject_Add(libmad
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/libmad/configure ${COMMON_ARGS}
)

ExternalProject_Add(amr
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz
URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/amr/configure ${COMMON_ARGS}
)

ExternalProject_Add(libmp3lame
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DOWNLOAD_DIR ${ARCHIVE_DIR}
Expand Down Expand Up @@ -72,11 +80,11 @@ ExternalProject_Add(opusfile

ExternalProject_Add(libsox
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad amr
DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --disable-openmp
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
)
10 changes: 6 additions & 4 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@ def load(
This function can handle all the codecs that underlying libsox can handle,
however it is tested on the following formats;

* WAV
* WAV, AMB

* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* 8-bit unsigned integer (WAV only)

* MP3
* FLAC
* OGG/VORBIS
* OPUS
* SPHERE
* AMR-NB

To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
Expand Down Expand Up @@ -119,7 +120,7 @@ def save(
Note:
Supported formats are;

* WAV
* WAV, AMB

* 32-bit floating-point
* 32-bit signed integer
Expand All @@ -130,6 +131,7 @@ def save(
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB

To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
Expand Down Expand Up @@ -160,7 +162,7 @@ def save(
filepath = str(filepath)
if compression is None:
ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph']:
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
compression = 0.
elif ext == 'mp3':
compression = -4.5
Expand Down
8 changes: 7 additions & 1 deletion torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,17 @@ void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal,
const double compression) {
const auto tensor = signal->getTensor();
auto tensor = signal->tensor;

validate_input_tensor(tensor);

const auto filetype = get_filetype(file_name);
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(signal->channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio.");
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(signal.get(), filetype);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
Expand Down
22 changes: 16 additions & 6 deletions torchaudio/csrc/sox_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ sox_encoding_t get_encoding(
return SOX_ENCODING_FLAC;
if (filetype == "ogg" || filetype == "vorbis")
return SOX_ENCODING_VORBIS;
if (filetype == "wav") {
if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8)
return SOX_ENCODING_UNSIGNED;
if (dtype == torch::kInt16)
Expand All @@ -236,7 +236,9 @@ sox_encoding_t get_encoding(
}
if (filetype == "sph")
return SOX_ENCODING_SIGN2;
throw std::runtime_error("Unsupported file type.");
if (filetype == "amr-nb")
return SOX_ENCODING_AMR_NB;
throw std::runtime_error("Unsupported file type: " + filetype);
}

unsigned get_precision(
Expand All @@ -248,7 +250,7 @@ unsigned get_precision(
return 24;
if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC;
if (filetype == "wav") {
if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8)
return 8;
if (dtype == torch::kInt16)
Expand All @@ -261,7 +263,13 @@ unsigned get_precision(
}
if (filetype == "sph")
return 32;
throw std::runtime_error("Unsupported file type.");
if (filetype == "amr-nb") {
TORCH_INTERNAL_ASSERT(
dtype == torch::kInt16,
"When saving to AMR-NB format, the input tensor must be int16 type.");
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
}

sox_signalinfo_t get_signalinfo(
Expand All @@ -287,11 +295,13 @@ sox_encodinginfo_t get_encodinginfo(
return compression;
if (filetype == "ogg" || filetype == "vorbis")
return compression;
if (filetype == "wav")
if (filetype == "wav" || filetype == "amb")
return 0.;
if (filetype == "sph")
return 0.;
throw std::runtime_error("Unsupported file type.");
if (filetype == "amr-nb")
return 0.;
throw std::runtime_error("Unsupported file type: " + filetype);
}();

return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
Expand Down