From c1c8ed302c3dab4b0d8015957a579dd1e26885ad Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 2 Feb 2021 23:19:00 +0000 Subject: [PATCH 01/12] Add encoding and bits_per_sample option to save function --- .../backend/sox_io/common.py | 12 + .../backend/sox_io/roundtrip_test.py | 4 +- .../backend/sox_io/save_test.py | 721 +++++++----------- .../backend/sox_io/torchscript_test.py | 22 +- .../common_utils/sox_utils.py | 12 +- torchaudio/backend/sox_io_backend.py | 95 ++- torchaudio/csrc/sox/effects_chain.cpp | 38 +- torchaudio/csrc/sox/io.cpp | 70 +- torchaudio/csrc/sox/io.h | 14 +- torchaudio/csrc/sox/utils.cpp | 243 ++++-- torchaudio/csrc/sox/utils.h | 25 +- 11 files changed, 635 insertions(+), 621 deletions(-) diff --git a/test/torchaudio_unittest/backend/sox_io/common.py b/test/torchaudio_unittest/backend/sox_io/common.py index eb85937236..c2538b2bc4 100644 --- a/test/torchaudio_unittest/backend/sox_io/common.py +++ b/test/torchaudio_unittest/backend/sox_io/common.py @@ -1,2 +1,14 @@ def name_func(func, _, params): return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' + + +def get_enc_params(dtype): + if dtype == 'float32': + return 'PCM_F', 32 + if dtype == 'int32': + return 'PCM_S', 32 + if dtype == 'int16': + return 'PCM_S', 16 + if dtype == 'uint8': + return 'PCM_U', 8 + raise ValueError(f'Unexpected dtype: {dtype}') diff --git a/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py b/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py index 85aa3019e0..8fff786388 100644 --- a/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py +++ b/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py @@ -12,6 +12,7 @@ ) from .common import ( name_func, + get_enc_params, ) @@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase): def test_wav(self, dtype, sample_rate, num_channels): """save/load round trip should not degrade data for wav formats""" original = get_wav_data(dtype, 
num_channels, normalize=False) + enc, bps = get_enc_params(dtype) data = original for i in range(10): path = self.get_temp_path(f'{i}.wav') - sox_io_backend.save(path, data, sample_rate) + sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps) data, sr = sox_io_backend.load(path, normalize=False) assert sr == sample_rate self.assertEqual(original, data) diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index fc378f9c70..dc06d48ce7 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -1,12 +1,13 @@ import io -import itertools +import unittest +from itertools import product -import torch from torchaudio.backend import sox_io_backend from parameterized import parameterized from torchaudio_unittest.common_utils import ( TempDirMixin, + TorchaudioTestCase, PytorchTestCase, skipIfNoExec, skipIfNoExtension, @@ -17,37 +18,61 @@ ) from .common import ( name_func, + get_enc_params, ) -class SaveTestBase(TempDirMixin, PytorchTestCase): - def assert_wav(self, dtype, sample_rate, num_channels, num_frames): - """`sox_io_backend.save` can save wav format.""" - path = self.get_temp_path('data.wav') - expected = get_wav_data(dtype, num_channels, num_frames=num_frames) - sox_io_backend.save(path, expected, sample_rate, dtype=None) - found, sr = load_wav(path) - assert sample_rate == sr - self.assertEqual(found, expected) - - def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): - """`sox_io_backend.save` can save mp3 format. - - mp3 encoding introduces delay and boundary effects so - we convert the resulting mp3 to wav and compare the results there - - | - | 1. 
Generate original wav file with SciPy +def _get_sox_encoding(encoding): + encodings = { + 'PCM_F': 'floating-point', + 'PCM_S': 'signed-integer', + 'PCM_U': 'unsigned-integer', + 'ULAW': 'u-law', + 'ALAW': 'a-law', + } + return encodings.get(encoding) + + +class SaveTestBase(TempDirMixin, TorchaudioTestCase): + def assert_save_consistency( + self, + format: str, + *, + compression: float = None, + encoding: str = None, + bits_per_sample: int = None, + sample_rate: float = 8000, + num_channels: int = 2, + num_frames: float = 3 * 8000, + test_mode: str = "path", + ): + """`save` function produces file that is comparable with `sox` command + + To compare that the file produced by `save` function against the file produced by + the equivalent `sox` command, we need to load both files. + But there are many formats that cannot be opened with common Python modules (like + SciPy). + So we use `sox` command to prepare the original data and convert the saved files + into a format that SciPy can read (PCM wav). + The following diagram illustrates this process. The difference is 2.1. and 3.1. + + This assumes that + - loading data with SciPy preserves the data well. + - converting the resulting files into WAV format with `sox` preserves the data well. + + x + | 1. Generate source wav file with SciPy | v -------------- wav ---------------- | | - | 2.1. load with scipy | 3.1. Convert to mp3 with Sox - | then save with torchaudio | + | 2.1. load with scipy | 3.1. Convert to the target + | then save it into the target | format depth with sox + | format with torchaudio | v v - mp3 mp3 + target format target format | | - | 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox + | 2.2. Convert to wav with sox | 3.2. 
Convert to wav with sox | | v v wav wav @@ -58,326 +83,242 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration): tensor -------> compare <--------- tensor """ - src_path = self.get_temp_path('1.reference.wav') - mp3_path = self.get_temp_path('2.1.torchaudio.mp3') - wav_path = self.get_temp_path('2.2.torchaudio.wav') - mp3_path_sox = self.get_temp_path('3.1.sox.mp3') - wav_path_sox = self.get_temp_path('3.2.sox.wav') + cmp_encoding = 'floating-point' + cmp_bit_depth = 32 - # 1. Generate original wav - data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) - save_wav(src_path, data, sample_rate) - # 2.1. Convert the original wav to mp3 with torchaudio - sox_io_backend.save( - mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate, dtype=None) - # 2.2. Convert the mp3 to wav with Sox - sox_utils.convert_audio_file(mp3_path, wav_path) - # 2.3. Load - found = load_wav(wav_path)[0] - - # 3.1. Convert the original wav to mp3 with SoX - sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate) - # 3.2. Convert the mp3 to wav with Sox - sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox) - # 3.3. Load - expected = load_wav(wav_path_sox)[0] - - self.assertEqual(found, expected) - - def assert_flac(self, sample_rate, num_channels, compression_level, duration): - """`sox_io_backend.save` can save flac format. 
- - This test takes the same strategy as mp3 to compare the result - """ - src_path = self.get_temp_path('1.reference.wav') - flc_path = self.get_temp_path('2.1.torchaudio.flac') - wav_path = self.get_temp_path('2.2.torchaudio.wav') - flc_path_sox = self.get_temp_path('3.1.sox.flac') - wav_path_sox = self.get_temp_path('3.2.sox.wav') + src_path = self.get_temp_path('1.source.wav') + tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}') + tst_path = self.get_temp_path('2.2.result.wav') + sox_path = self.get_temp_path(f'3.1.sox.{format}') + ref_path = self.get_temp_path('3.2.ref.wav') # 1. Generate original wav - data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) + data = get_wav_data('int32', num_channels, normalize=False, num_frames=num_frames) save_wav(src_path, data, sample_rate) - # 2.1. Convert the original wav to flac with torchaudio - sox_io_backend.save( - flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level, dtype=None) - # 2.2. Convert the flac to wav with Sox - # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. - sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32) - # 2.3. Load - found = load_wav(wav_path)[0] - - # 3.1. Convert the original wav to flac with SoX - sox_utils.convert_audio_file(src_path, flc_path_sox, compression=compression_level) - # 3.2. Convert the flac to wav with Sox - # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. - sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32) - # 3.3. Load - expected = load_wav(wav_path_sox)[0] - - self.assertEqual(found, expected) - - def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration): - """`sox_io_backend.save` can save vorbis format. 
- - This test takes the same strategy as mp3 to compare the result - """ - src_path = self.get_temp_path('1.reference.wav') - vbs_path = self.get_temp_path('2.1.torchaudio.vorbis') - wav_path = self.get_temp_path('2.2.torchaudio.wav') - vbs_path_sox = self.get_temp_path('3.1.sox.vorbis') - wav_path_sox = self.get_temp_path('3.2.sox.wav') - # 1. Generate original wav - data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate) - save_wav(src_path, data, sample_rate) - # 2.1. Convert the original wav to vorbis with torchaudio - sox_io_backend.save( - vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level, dtype=None) - # 2.2. Convert the vorbis to wav with Sox - sox_utils.convert_audio_file(vbs_path, wav_path) - # 2.3. Load - found = load_wav(wav_path)[0] - - # 3.1. Convert the original wav to vorbis with SoX - sox_utils.convert_audio_file(src_path, vbs_path_sox, compression=quality_level) - # 3.2. Convert the vorbis to wav with Sox - sox_utils.convert_audio_file(vbs_path_sox, wav_path_sox) - # 3.3. Load - expected = load_wav(wav_path_sox)[0] - - # sox's vorbis encoding has some random boundary effect, which cause small number of - # samples yields higher descrepency than the others. - # so we allow small portions of data to be outside of absolute torelance. - # make sure to pass somewhat long duration - atol = 1.0e-4 - max_failure_allowed = 0.01 # this percent of samples are allowed to outside of atol. - failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel() - if failure_ratio > max_failure_allowed: - # it's failed and this will give a better error message. 
- self.assertEqual(found, expected, atol=atol, rtol=1.3e-6) - - def assert_vorbis(self, *args, **kwargs): - # sox's vorbis encoding has some randomness, so we run tests multiple time - max_retry = 5 - error = None - for _ in range(max_retry): - try: - self._assert_vorbis(*args, **kwargs) - break - except AssertionError as e: - error = e + # 2.1. Convert the original wav to target format with torchaudio + data = load_wav(src_path, normalize=False)[0] + if test_mode == "path": + sox_io_backend.save( + tgt_path, data, sample_rate, + compression=compression, encoding=encoding, bits_per_sample=bits_per_sample) + elif test_mode == "fileobj": + with open(tgt_path, 'bw') as file_: + sox_io_backend.save( + file_, data, sample_rate, + format=format, compression=compression, + encoding=encoding, bits_per_sample=bits_per_sample) + elif test_mode == "bytesio": + file_ = io.BytesIO() + sox_io_backend.save( + file_, data, sample_rate, + format=format, compression=compression, + encoding=encoding, bits_per_sample=bits_per_sample) + file_.seek(0) + with open(tgt_path, 'bw') as f: + f.write(file_.read()) else: - raise error - - def assert_sphere(self, sample_rate, num_channels, duration): - """`sox_io_backend.save` can save sph format. - - This test takes the same strategy as mp3 to compare the result - """ - src_path = self.get_temp_path('1.reference.wav') - flc_path = self.get_temp_path('2.1.torchaudio.sph') - wav_path = self.get_temp_path('2.2.torchaudio.wav') - flc_path_sox = self.get_temp_path('3.1.sox.sph') - wav_path_sox = self.get_temp_path('3.2.sox.wav') - - # 1. Generate original wav - data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) - save_wav(src_path, data, sample_rate) - # 2.1. Convert the original wav to sph with torchaudio - sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate, dtype=None) - # 2.2. 
Convert the sph to wav with Sox - # converting to 32 bit because sph file has 24 bit depth which scipy cannot handle. - sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32) - # 2.3. Load - found = load_wav(wav_path)[0] - - # 3.1. Convert the original wav to sph with SoX - sox_utils.convert_audio_file(src_path, flc_path_sox) - # 3.2. Convert the sph to wav with Sox - # converting to 32 bit because sph file has 24 bit depth which scipy cannot handle. - sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32) - # 3.3. Load - expected = load_wav(wav_path_sox)[0] + raise ValueError(f"Unexpected test mode: {test_mode}") + # 2.2. Convert the target format to wav with sox + sox_utils.convert_audio_file( + tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + # 2.3. Load with SciPy + found = load_wav(tst_path, normalize=False)[0] + + # 3.1. Convert the original wav to target format with sox + sox_encoding = _get_sox_encoding(encoding) + sox_utils.convert_audio_file( + src_path, sox_path, + compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample) + # 3.2. Convert the target format to wav with sox + sox_utils.convert_audio_file( + sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + # 3.3. Load with SciPy + expected = load_wav(ref_path, normalize=False)[0] self.assertEqual(found, expected) - def assert_amb(self, dtype, sample_rate, num_channels, duration): - """`sox_io_backend.save` can save amb format. 
- This test takes the same strategy as mp3 to compare the result - """ - src_path = self.get_temp_path('1.reference.wav') - amb_path = self.get_temp_path('2.1.torchaudio.amb') - wav_path = self.get_temp_path('2.2.torchaudio.wav') - amb_path_sox = self.get_temp_path('3.1.sox.amb') - wav_path_sox = self.get_temp_path('3.2.sox.wav') +def nested_params(*params): + def _name_func(func, _, params): + strs = [] + for arg in params.args: + if isinstance(arg, tuple): + strs.append("_".join(str(a) for a in arg)) + else: + strs.append(str(arg)) + return f'{func.__name__}_{"_".join(strs)}' - # 1. Generate original wav - data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) - save_wav(src_path, data, sample_rate) - # 2.1. Convert the original wav to amb with torchaudio - sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None) - # 2.2. Convert the amb to wav with Sox - sox_utils.convert_audio_file(amb_path, wav_path) - # 2.3. Load - found = load_wav(wav_path)[0] - - # 3.1. Convert the original wav to amb with SoX - sox_utils.convert_audio_file(src_path, amb_path_sox) - # 3.2. Convert the amb to wav with Sox - sox_utils.convert_audio_file(amb_path_sox, wav_path_sox) - # 3.3. Load - expected = load_wav(wav_path_sox)[0] - - self.assertEqual(found, expected) - - def assert_amr_nb(self, duration): - """`sox_io_backend.save` can save amr_nb format. - - This test takes the same strategy as mp3 to compare the result - """ - sample_rate = 8000 - num_channels = 1 - src_path = self.get_temp_path('1.reference.wav') - amr_path = self.get_temp_path('2.1.torchaudio.amr-nb') - wav_path = self.get_temp_path('2.2.torchaudio.wav') - amr_path_sox = self.get_temp_path('3.1.sox.amr-nb') - wav_path_sox = self.get_temp_path('3.2.sox.wav') - - # 1. Generate original wav - data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate) - save_wav(src_path, data, sample_rate) - # 2.1. 
Convert the original wav to amr_nb with torchaudio - sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None) - # 2.2. Convert the amr_nb to wav with Sox - sox_utils.convert_audio_file(amr_path, wav_path) - # 2.3. Load - found = load_wav(wav_path)[0] - - # 3.1. Convert the original wav to amr_nb with SoX - sox_utils.convert_audio_file(src_path, amr_path_sox) - # 3.2. Convert the amr_nb to wav with Sox - sox_utils.convert_audio_file(amr_path_sox, wav_path_sox) - # 3.3. Load - expected = load_wav(wav_path_sox)[0] - - self.assertEqual(found, expected) + return parameterized.expand( + list(product(*params)), + name_func=_name_func + ) @skipIfNoExec('sox') @skipIfNoExtension -class TestSave(SaveTestBase): - @parameterized.expand(list(itertools.product( - ['float32', 'int32', 'int16', 'uint8'], - [8000, 16000], - [1, 2], - )), name_func=name_func) - def test_wav(self, dtype, sample_rate, num_channels): - """`sox_io_backend.save` can save wav format.""" - self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) - - @parameterized.expand(list(itertools.product( - ['float32'], - [16000], - [2], - )), name_func=name_func) - def test_wav_large(self, dtype, sample_rate, num_channels): - """`sox_io_backend.save` can save large wav file.""" - two_hours = 2 * 60 * 60 * sample_rate - self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours) - - @parameterized.expand(list(itertools.product( - ['float32', 'int32', 'int16', 'uint8'], - [4, 8, 16, 32], - )), name_func=name_func) - def test_multiple_channels(self, dtype, num_channels): - """`sox_io_backend.save` can save wav with more than 2 channels.""" +class SaveTest(SaveTestBase): + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('PCM_U', 8), + ('PCM_S', 16), + ('PCM_S', 32), + ('PCM_F', 32), + ('ULAW', 8), + ('ALAW', 8), + ], + ) + def test_save_wav(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency( + "wav", 
encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + None, + -4.2, + -0.2, + 0, + 0.2, + 96, + 128, + 160, + 192, + 224, + 256, + 320, + ], + ) + def test_save_mp3(self, test_mode, bit_rate): + if test_mode in ["fileobj", "bytesio"]: + if bit_rate is not None and bit_rate < 1: + raise unittest.SkipTest( + "mp3 format with variable bit rate is known to " + "not yield the exact same result as sox command.") + self.assert_save_consistency( + "mp3", compression=bit_rate, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + None, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ], + ) + def test_save_flac(self, test_mode, compression_level): + self.assert_save_consistency( + "flac", compression=compression_level, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + None, + -1, + 0, + 1, + 2, + 3, + 3.6, + 5, + 10, + ], + ) + def test_save_vorbis(self, test_mode, quality_level): + self.assert_save_consistency( + "vorbis", compression=quality_level, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('PCM_S', 8, ), + ('PCM_S', 16, ), + ('PCM_S', 24, ), + ('PCM_S', 32, ), + ('ULAW', 8), + ('ALAW', 8), + ('ALAW', 16), + ('ALAW', 24), + ('ALAW', 32), + ], + ) + def test_save_sphere(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency( + "sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('PCM_U', 8, ), + ('PCM_S', 16, ), + ('PCM_S', 24, ), + ('PCM_S', 32, ), + ('PCM_F', 32, ), + ('PCM_F', 64, ), + ('ULAW', 8, ), + ('ALAW', 8, ), + ], + ) + def test_save_amb(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency( + "amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", 
"fileobj", "bytesio"], + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ], + ) + def test_save_amr_nb(self, test_mode, bit_rate): + self.assert_save_consistency( + "amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode) + + @parameterized.expand([ + ("wav", "PCM_S", 16), + ("mp3", ), + ("flac", ), + ("vorbis", ), + ("sph", "PCM_S", 16), + ("amr-nb", ), + ("amb", "PCM_S", 16), + ], name_func=name_func) + def test_save_large(self, format, encoding=None, bits_per_sample=None): + """`sox_io_backend.save` can save large files.""" sample_rate = 8000 - self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) - - @parameterized.expand(list(itertools.product( - [8000, 16000], - [1, 2], - [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], - )), name_func=name_func) - def test_mp3(self, sample_rate, num_channels, bit_rate): - """`sox_io_backend.save` can save mp3 format.""" - self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) - - @parameterized.expand(list(itertools.product( - [16000], - [2], - [128], - )), name_func=name_func) - def test_mp3_large(self, sample_rate, num_channels, bit_rate): - """`sox_io_backend.save` can save large mp3 file.""" - two_hours = 2 * 60 * 60 - self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours) - - @parameterized.expand(list(itertools.product( - [8000, 16000], - [1, 2], - [None] + list(range(9)), - )), name_func=name_func) - def test_flac(self, sample_rate, num_channels, compression_level): - """`sox_io_backend.save` can save flac format.""" - self.assert_flac(sample_rate, num_channels, compression_level, duration=1) - - @parameterized.expand(list(itertools.product( - [16000], - [2], - [0], - )), name_func=name_func) - def test_flac_large(self, sample_rate, num_channels, compression_level): - """`sox_io_backend.save` can save large flac file.""" - two_hours = 2 * 60 * 60 - self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours) - - 
@parameterized.expand(list(itertools.product( - [8000, 16000], - [1, 2], - [None, -1, 0, 1, 2, 3, 3.6, 5, 10], - )), name_func=name_func) - def test_vorbis(self, sample_rate, num_channels, quality_level): - """`sox_io_backend.save` can save vorbis format.""" - self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20) - - # note: torchaudio can load large vorbis file, but cannot save large volbis file - # the following test causes Segmentation fault - # - ''' - @parameterized.expand(list(itertools.product( - [16000], - [2], - [10], - )), name_func=name_func) - def test_vorbis_large(self, sample_rate, num_channels, quality_level): - """`sox_io_backend.save` can save large vorbis file correctly.""" - two_hours = 2 * 60 * 60 - self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) - ''' - - @parameterized.expand(list(itertools.product( - [8000, 16000], - [1, 2], - )), name_func=name_func) - def test_sphere(self, sample_rate, num_channels): - """`sox_io_backend.save` can save sph format.""" - self.assert_sphere(sample_rate, num_channels, duration=1) - - @parameterized.expand(list(itertools.product( - ['float32', 'int32', 'int16', 'uint8'], - [8000, 16000], - [1, 2], - )), name_func=name_func) - def test_amb(self, dtype, sample_rate, num_channels): - """`sox_io_backend.save` can save amb format.""" - self.assert_amb(dtype, sample_rate, num_channels, duration=1) - - def test_amr_nb(self): - """`sox_io_backend.save` can save amr-nb format.""" - self.assert_amr_nb(duration=1) + one_hour = 60 * 60 * sample_rate + self.assert_save_consistency( + format, num_channels=1, sample_rate=8000, num_frames=one_hour, + encoding=encoding, bits_per_sample=bits_per_sample) + + @parameterized.expand([ + (32, ), + (64, ), + (128, ), + (256, ), + ], name_func=name_func) + def test_save_multi_channels(self, num_channels): + """`sox_io_backend.save` can save audio with many channels""" + self.assert_save_consistency( + "wav", encoding="PCM_U", 
bits_per_sample=16, + num_channels=num_channels) @skipIfNoExec('sox') @@ -385,136 +326,40 @@ def test_amr_nb(self): class TestSaveParams(TempDirMixin, PytorchTestCase): """Test the correctness of optional parameters of `sox_io_backend.save`""" @parameterized.expand([(True, ), (False, )], name_func=name_func) - def test_channels_first(self, channels_first): + def test_save_channels_first(self, channels_first): """channels_first swaps axes""" path = self.get_temp_path('data.wav') - data = get_wav_data('int32', 2, channels_first=channels_first) + data = get_wav_data( + 'int16', 2, channels_first=channels_first, normalize=False) sox_io_backend.save( - path, data, 8000, channels_first=channels_first, dtype=None) - found = load_wav(path)[0] + path, data, 8000, channels_first=channels_first) + found = load_wav(path, normalize=False)[0] expected = data if channels_first else data.transpose(1, 0) self.assertEqual(found, expected) @parameterized.expand([ 'float32', 'int32', 'int16', 'uint8' ], name_func=name_func) - def test_noncontiguous(self, dtype): + def test_save_noncontiguous(self, dtype): """Noncontiguous tensors are saved correctly""" path = self.get_temp_path('data.wav') - expected = get_wav_data(dtype, 4)[::2, ::2] + enc, bps = get_enc_params(dtype) + expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2] assert not expected.is_contiguous() - sox_io_backend.save(path, expected, 8000, dtype=None) - found = load_wav(path)[0] + sox_io_backend.save( + path, expected, 8000, encoding=enc, bits_per_sample=bps) + found = load_wav(path, normalize=False)[0] self.assertEqual(found, expected) @parameterized.expand([ 'float32', 'int32', 'int16', 'uint8', ]) - def test_tensor_preserve(self, dtype): + def test_save_tensor_preserve(self, dtype): """save function should not alter Tensor""" path = self.get_temp_path('data.wav') - expected = get_wav_data(dtype, 4)[::2, ::2] + expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2] data = expected.clone() - 
sox_io_backend.save(path, data, 8000, dtype=None) + sox_io_backend.save(path, data, 8000) self.assertEqual(data, expected) - - @parameterized.expand([ - ('float32', torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32)), - ('int32', torch.tensor([-2147483648, -1073741824, 0, 1073741824, 2147483647]).to(torch.int32)), - ('int16', torch.tensor([-32768, -16384, 0, 16384, 32767]).to(torch.int16)), - ('uint8', torch.tensor([0, 64, 128, 192, 255]).to(torch.uint8)), - ]) - def test_dtype_conversion(self, dtype, expected): - """`save` performs dtype conversion on float32 src tensors only.""" - path = self.get_temp_path("data.wav") - data = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32).view(-1, 1) - sox_io_backend.save(path, data, 8000, dtype=dtype) - found = load_wav(path, normalize=False)[0] - self.assertEqual(found, expected.view(-1, 1)) - - -@skipIfNoExtension -@skipIfNoExec('sox') -class TestFileObject(SaveTestBase): - """ - We campare the result of file-like object input against file path input because - `save` function is rigrously tested for file path inputs to match libsox's result, - """ - @parameterized.expand([ - ('wav', None), - ('mp3', 128), - ('mp3', 320), - ('flac', 0), - ('flac', 5), - ('flac', 8), - ('vorbis', -1), - ('vorbis', 10), - ('amb', None), - ]) - def test_fileobj(self, ext, compression): - """Saving audio to file object returns the same result as via file path.""" - sample_rate = 16000 - dtype = 'float32' - num_channels = 2 - num_frames = 16000 - channels_first = True - - data = get_wav_data(dtype, num_channels, num_frames=num_frames) - - ref_path = self.get_temp_path(f'reference.{ext}') - res_path = self.get_temp_path(f'test.{ext}') - sox_io_backend.save( - ref_path, data, channels_first=channels_first, - sample_rate=sample_rate, compression=compression, dtype=None) - with open(res_path, 'wb') as fileobj: - sox_io_backend.save( - fileobj, data, channels_first=channels_first, - sample_rate=sample_rate, compression=compression, 
format=ext, dtype=None) - - expected_data, _ = sox_io_backend.load(ref_path) - data, sr = sox_io_backend.load(res_path) - - assert sample_rate == sr - self.assertEqual(expected_data, data) - - @parameterized.expand([ - ('wav', None), - ('mp3', 128), - ('mp3', 320), - ('flac', 0), - ('flac', 5), - ('flac', 8), - ('vorbis', -1), - ('vorbis', 10), - ('amb', None), - ]) - def test_bytesio(self, ext, compression): - """Saving audio to BytesIO object returns the same result as via file path.""" - sample_rate = 16000 - dtype = 'float32' - num_channels = 2 - num_frames = 16000 - channels_first = True - - data = get_wav_data(dtype, num_channels, num_frames=num_frames) - - ref_path = self.get_temp_path(f'reference.{ext}') - res_path = self.get_temp_path(f'test.{ext}') - sox_io_backend.save( - ref_path, data, channels_first=channels_first, - sample_rate=sample_rate, compression=compression, dtype=None) - fileobj = io.BytesIO() - sox_io_backend.save( - fileobj, data, channels_first=channels_first, - sample_rate=sample_rate, compression=compression, format=ext, dtype=None) - fileobj.seek(0) - with open(res_path, 'wb') as file_: - file_.write(fileobj.read()) - - expected_data, _ = sox_io_backend.load(ref_path) - data, sr = sox_io_backend.load(res_path) - - assert sample_rate == sr - self.assertEqual(expected_data, data) diff --git a/test/torchaudio_unittest/backend/sox_io/torchscript_test.py b/test/torchaudio_unittest/backend/sox_io/torchscript_test.py index dad678cbc0..4e429877fd 100644 --- a/test/torchaudio_unittest/backend/sox_io/torchscript_test.py +++ b/test/torchaudio_unittest/backend/sox_io/torchscript_test.py @@ -17,6 +17,7 @@ ) from .common import ( name_func, + get_enc_params, ) @@ -35,8 +36,12 @@ def py_save_func( sample_rate: int, channels_first: bool = True, compression: Optional[float] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, ): - torchaudio.save(filepath, tensor, sample_rate, channels_first, compression) + torchaudio.save( 
+ filepath, tensor, sample_rate, channels_first, + compression, None, encoding, bits_per_sample) @skipIfNoExec('sox') @@ -102,15 +107,16 @@ def test_save_wav(self, dtype, sample_rate, num_channels): torch.jit.script(py_save_func).save(script_path) ts_save_func = torch.jit.load(script_path) - expected = get_wav_data(dtype, num_channels) + expected = get_wav_data(dtype, num_channels, normalize=False) py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav') ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav') + enc, bps = get_enc_params(dtype) - py_save_func(py_path, expected, sample_rate, True, None) - ts_save_func(ts_path, expected, sample_rate, True, None) + py_save_func(py_path, expected, sample_rate, True, None, enc, bps) + ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps) - py_data, py_sr = load_wav(py_path) - ts_data, ts_sr = load_wav(ts_path) + py_data, py_sr = load_wav(py_path, normalize=False) + ts_data, ts_sr = load_wav(ts_path, normalize=False) self.assertEqual(sample_rate, py_sr) self.assertEqual(sample_rate, ts_sr) @@ -131,8 +137,8 @@ def test_save_flac(self, sample_rate, num_channels, compression_level): py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac') ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac') - py_save_func(py_path, expected, sample_rate, True, compression_level) - ts_save_func(ts_path, expected, sample_rate, True, compression_level) + py_save_func(py_path, expected, sample_rate, True, compression_level, None, None) + ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None) # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. 
py_path_wav = f'{py_path}.wav' diff --git a/test/torchaudio_unittest/common_utils/sox_utils.py b/test/torchaudio_unittest/common_utils/sox_utils.py index 0267017393..fd0949f31a 100644 --- a/test/torchaudio_unittest/common_utils/sox_utils.py +++ b/test/torchaudio_unittest/common_utils/sox_utils.py @@ -1,3 +1,4 @@ +import sys import subprocess import warnings @@ -32,6 +33,7 @@ def gen_audio_file( command = [ 'sox', '-V3', # verbose + '--no-dither', # disable automatic dithering '-R', # -R is supposed to be repeatable, though the implementation looks suspicious # and not setting the seed to a fixed value. @@ -61,21 +63,23 @@ def gen_audio_file( ] if attenuation is not None: command += ['vol', f'-{attenuation}dB'] - print(' '.join(command)) + print(' '.join(command), file=sys.stderr) subprocess.run(command, check=True) def convert_audio_file( src_path, dst_path, - *, bit_depth=None, compression=None): + *, encoding=None, bit_depth=None, compression=None): """Convert audio file with `sox` command.""" - command = ['sox', '-V3', '-R', str(src_path)] + command = ['sox', '-V3', '--no-dither', '-R', str(src_path)] + if encoding is not None: + command += ['--encoding', str(encoding)] if bit_depth is not None: command += ['--bits', str(bit_depth)] if compression is not None: command += ['--compression', str(compression)] command += [dst_path] - print(' '.join(command)) + print(' '.join(command), file=sys.stderr) subprocess.run(command, check=True) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 77d172af57..67f9f3e0ed 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -1,5 +1,4 @@ import os -import warnings from typing import Tuple, Optional import torch @@ -152,26 +151,6 @@ def load( filepath, frame_offset, num_frames, normalize, channels_first, format) -@torch.jit.unused -def _save( - filepath: str, - src: torch.Tensor, - sample_rate: int, - channels_first: bool = True, - compression: 
Optional[float] = None, - format: Optional[str] = None, - dtype: Optional[str] = None, -): - if hasattr(filepath, 'write'): - if format is None: - raise RuntimeError('`format` is required when saving to file object.') - torchaudio._torchaudio.save_audio_fileobj( - filepath, src, sample_rate, channels_first, compression, format, dtype) - else: - torch.ops.torchaudio.sox_io_save_audio_file( - os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype) - - @_mod_utils.requires_module('torchaudio._torchaudio') def save( filepath: str, @@ -180,7 +159,8 @@ def save( channels_first: bool = True, compression: Optional[float] = None, format: Optional[str] = None, - dtype: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, ): """Save audio data to file. @@ -223,24 +203,65 @@ def save( | and lowest quality. Default: ``3``. See the detail at http://sox.sourceforge.net/soxformat.html. - format (str, optional): Output audio format. - This is required when the output audio format cannot be infered from - ``filepath``, (such as file extension or ``name`` attribute of the given file object). - dtype (str, optional): Output tensor dtype. - Valid values: ``"uint8", "int16", "int32", "float32", "float64", None`` - ``dtype=None`` means no conversion is performed. - ``dtype`` parameter is only effective for ``float32`` Tensor. + format (str, optional): + If provided, overwrite the audio format. This parameter is required in cases where + the ``filepath`` parameter is file-like object or ``filepath`` parameter represents + the path to a file on a local system but missing file extension or has different + extension. + When not provided, the value of file extension is used. + Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, + ``"amb"``, ``"flac"`` and ``"sph"``. + encoding (str, optional): + Changes the encoding for the supported formats, such as ``"wav"``, ``"sph"``. + and ``"amb"``. 
+ Valid values are ``"PCM_S"`` (signed integer Linear PCM), ``"PCM_U"`` + (unsigned integer Linear PCM), ``"PCM_F"`` (floating point PCM), + ``"ULAW"`` (mu-law) and ``"ALAW"`` (a-law). + Different formats support different set of encodings. Providing a value that is not + supported by the format will not cause an error, but will fallback to its default value. + + If not provided, the default values are picked based on ``format`` and + ``bits_per_sample``; + + For ``"wav"`` and ``"amb"`` formats, the default value is; + - ``"PCM_U"`` if ``bits_per_sample=8`` + - ``"PCM_S"`` otherwise + For ``"sph"`` format, the default value is ``"PCM_S"``. + + bits_per_sample (int, optional): + Change the bit depth for the supported formats, such as ``"wav"``, ``"flac"``, + ``"sph"``, and ``"amb"``. + Valid values are ``8``, ``16``, ``32`` and ``64``. + Different formats support different set of encodings. Providing a value that is not + supported by the format will not cause an error, but will fallback to its default value. + + If not provided, the default values are picked based on ``format`` and ``"encoding"``; + + For ``"wav"`` format, the default value is; + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. + - ``32`` if ``encoding`` is ``"PCM_F"`` + + For ``"flac"`` format, the default value is ``24``. + + For ``"sph"`` format, the default value is; + - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided. + - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` + + For ``"amb"`` format, the default value is; + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. + - ``32`` if ``encoding`` is ``"PCM_F"`` """ - if src.dtype == torch.float32 and dtype is None: - warnings.warn( - '`dtype` default value will be changed to `int16` in 0.9 release.' - 'Specify `dtype` to suppress this warning.' 
- ) if not torch.jit.is_scripting(): - _save(filepath, src, sample_rate, channels_first, compression, format, dtype) - return + if hasattr(filepath, 'write'): + torchaudio._torchaudio.save_audio_fileobj( + filepath, src, sample_rate, channels_first, compression, + format, encoding, bits_per_sample) + return + filepath = os.fspath(filepath) torch.ops.torchaudio.sox_io_save_audio_file( - filepath, src, sample_rate, channels_first, compression, format, dtype) + filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample) @_mod_utils.requires_module('torchaudio._torchaudio') diff --git a/torchaudio/csrc/sox/effects_chain.cpp b/torchaudio/csrc/sox/effects_chain.cpp index b98010f67b..041ba020bc 100644 --- a/torchaudio/csrc/sox/effects_chain.cpp +++ b/torchaudio/csrc/sox/effects_chain.cpp @@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { // Ensure that it's a multiple of the number of channels *osamp -= *osamp % num_channels; - // Slice the input Tensor and unnormalize the values + // Slice the input Tensor const auto tensor_ = [&]() { auto i_frame = index / num_channels; auto num_frames = *osamp / num_channels; auto t = (priv->channels_first) ? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t() : tensor.index({Slice(i_frame, i_frame + num_frames), Slice()}); - return unnormalize_wav(t.reshape({-1})).contiguous(); + return t.reshape({-1}).contiguous(); }(); - priv->index += *osamp; - - // Write data to SoxEffectsChain buffer. 
- auto ptr = tensor_.data_ptr(); - std::copy(ptr, ptr + *osamp, obuf); + // Convert to sox_sample_t (int32_t) and write to buffer + SOX_SAMPLE_LOCALS; + const auto dtype = tensor_.dtype(); + if (dtype == torch::kFloat32) { + auto ptr = tensor_.data_ptr(); + for (size_t i = 0; i < *osamp; ++i) { + obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips); + } + } else if (dtype == torch::kInt32) { + auto ptr = tensor_.data_ptr(); + for (size_t i = 0; i < *osamp; ++i) { + obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips); + } + } else if (dtype == torch::kInt16) { + auto ptr = tensor_.data_ptr(); + for (size_t i = 0; i < *osamp; ++i) { + obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips); + } + } else if (dtype == torch::kUInt8) { + auto ptr = tensor_.data_ptr(); + for (size_t i = 0; i < *osamp; ++i) { + obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips); + } + } else { + throw std::runtime_error("Unexpected dtype."); + } + priv->index += *osamp; return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; } @@ -430,7 +452,7 @@ int fileobj_output_flow( fflush(fp); // Copy the encoded chunk to python object. 
- fileobj->attr("write")(py::bytes(*buffer, *buffer_size)); + fileobj->attr("write")(py::bytes(*buffer, ftell(fp))); // Reset FILE* sf->tell_off = 0; diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index ac10911507..c311155d2b 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -14,32 +14,31 @@ namespace { std::string get_encoding(sox_encoding_t encoding) { switch (encoding) { - case SOX_ENCODING_UNKNOWN: - return "UNKNOWN"; case SOX_ENCODING_SIGN2: - return "PCM_S"; + return ENCODING_PCM_SIGNED; case SOX_ENCODING_UNSIGNED: - return "PCM_U"; + return ENCODING_PCM_UNSIGNED; case SOX_ENCODING_FLOAT: - return "PCM_F"; + return ENCODING_PCM_FLOAT; case SOX_ENCODING_FLAC: - return "FLAC"; + return ENCODING_FLAC; case SOX_ENCODING_ULAW: - return "ULAW"; + return ENCODING_ULAW; case SOX_ENCODING_ALAW: - return "ALAW"; + return ENCODING_ALAW; case SOX_ENCODING_MP3: - return "MP3"; + return ENCODING_MP3; case SOX_ENCODING_VORBIS: - return "VORBIS"; + return ENCODING_VORBIS; case SOX_ENCODING_AMR_WB: - return "AMR_WB"; + return ENCODING_AMR_WB; case SOX_ENCODING_AMR_NB: - return "AMR_NB"; + return ENCODING_AMR_NB; case SOX_ENCODING_OPUS: - return "OPUS"; + return ENCODING_OPUS; + case SOX_ENCODING_UNKNOWN: default: - return "UNKNOWN"; + return ENCODING_UNKNOWN; } } @@ -116,35 +115,27 @@ void save_audio_file( torch::Tensor tensor, int64_t sample_rate, bool channels_first, - c10::optional compression, - c10::optional format, - c10::optional dtype) { + c10::optional& compression, + c10::optional& format, + c10::optional& encoding, + c10::optional& bits_per_sample) { validate_input_tensor(tensor); - if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) { - throw std::runtime_error( - "dtype conversion only supported for float32 tensors"); - } - const auto tgt_dtype = - (tensor.dtype() == torch::kFloat32 && dtype.has_value()) - ? 
get_dtype_from_str(dtype.value()) - : tensor.dtype(); - const auto filetype = [&]() { if (format.has_value()) return format.value(); return get_filetype(path); }(); + if (filetype == "amr-nb") { const auto num_channels = tensor.size(channels_first ? 0 : 1); TORCH_CHECK( num_channels == 1, "amr-nb format only supports single channel audio."); - tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = - get_encodinginfo_for_save(filetype, tgt_dtype, compression); + const auto encoding_info = get_encodinginfo_for_save( + filetype, compression, encoding, bits_per_sample); SoxFormat sf(sox_open_write( path.c_str(), @@ -258,19 +249,17 @@ void save_audio_fileobj( torch::Tensor tensor, int64_t sample_rate, bool channels_first, - c10::optional compression, - std::string filetype, - c10::optional dtype) { + c10::optional& compression, + c10::optional& format, + c10::optional& encoding, + c10::optional& bits_per_sample) { validate_input_tensor(tensor); - if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) { + if (!format.has_value()) { throw std::runtime_error( - "dtype conversion only supported for float32 tensors"); + "`format` is required when saving to file object."); } - const auto tgt_dtype = - (tensor.dtype() == torch::kFloat32 && dtype.has_value()) - ? get_dtype_from_str(dtype.value()) - : tensor.dtype(); + const auto filetype = format.value(); if (filetype == "amr-nb") { const auto num_channels = tensor.size(channels_first ? 
0 : 1); @@ -278,12 +267,11 @@ void save_audio_fileobj( throw std::runtime_error( "amr-nb format only supports single channel audio."); } - tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = - get_encodinginfo_for_save(filetype, tgt_dtype, compression); + const auto encoding_info = get_encodinginfo_for_save( + filetype, compression, encoding, bits_per_sample); AutoReleaseBuffer buffer; diff --git a/torchaudio/csrc/sox/io.h b/torchaudio/csrc/sox/io.h index 6129fac1b7..614153ba52 100644 --- a/torchaudio/csrc/sox/io.h +++ b/torchaudio/csrc/sox/io.h @@ -28,9 +28,10 @@ void save_audio_file( torch::Tensor tensor, int64_t sample_rate, bool channels_first, - c10::optional compression, - c10::optional format, - c10::optional dtype); + c10::optional& compression, + c10::optional& format, + c10::optional& encoding, + c10::optional& bits_per_sample); #ifdef TORCH_API_INCLUDE_EXTENSION_H @@ -51,9 +52,10 @@ void save_audio_fileobj( torch::Tensor tensor, int64_t sample_rate, bool channels_first, - c10::optional compression, - std::string filetype, - c10::optional dtype); + c10::optional& compression, + c10::optional& format, + c10::optional& encoding, + c10::optional& bits_per_sample); #endif // TORCH_API_INCLUDE_EXTENSION_H diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 983b7829fe..17a585b777 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -163,22 +163,32 @@ torch::Tensor convert_to_tensor( const caffe2::TypeMeta dtype, const bool normalize, const bool channels_first) { - auto t = torch::from_blob( - buffer, {num_samples / num_channels, num_channels}, torch::kInt32); - // Note: Tensor created from_blob does not own data but borrwos - // So make sure to create a new copy after processing samples. 
+ torch::Tensor t; + uint64_t dummy; + SOX_SAMPLE_LOCALS; if (normalize || dtype == torch::kFloat32) { - t = t.to(torch::kFloat32); - t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.; + t = torch::empty( + {num_samples / num_channels, num_channels}, torch::kFloat32); + auto ptr = t.data_ptr(); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy); + } } else if (dtype == torch::kInt32) { - t = t.clone(); + t = torch::from_blob( + buffer, {num_samples / num_channels, num_channels}, torch::kInt32) + .clone(); } else if (dtype == torch::kInt16) { - t.floor_divide_(1 << 16); - t = t.to(torch::kInt16); + t = torch::empty({num_samples / num_channels, num_channels}, torch::kInt16); + auto ptr = t.data_ptr(); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy); + } } else if (dtype == torch::kUInt8) { - t.floor_divide_(1 << 24); - t += 128; - t = t.to(torch::kUInt8); + t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8); + auto ptr = t.data_ptr(); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy); + } } else { throw std::runtime_error("Unsupported dtype."); } @@ -188,63 +198,155 @@ torch::Tensor convert_to_tensor( return t.contiguous(); } -torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) { - const auto dtype = input_tensor.dtype(); - auto tensor = input_tensor; - if (dtype == torch::kFloat32) { - double multi_pos = 2147483647.; - double multi_neg = -2147483648.; - auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg; - tensor = tensor.to(torch::dtype(torch::kFloat64)); - tensor *= mult; - tensor.clamp_(multi_neg, multi_pos); - tensor = tensor.to(torch::dtype(torch::kInt32)); - } else if (dtype == torch::kInt32) { - // already denormalized - } else if (dtype == torch::kInt16) { - tensor = tensor.to(torch::dtype(torch::kInt32)); - tensor *= ((tensor != 0) * 65536); - } else if 
(dtype == torch::kUInt8) { - tensor = tensor.to(torch::dtype(torch::kInt32)); - tensor -= 128; - tensor *= 16777216; - } else { - throw std::runtime_error("Unexpected dtype."); - } - return tensor; -} - const std::string get_filetype(const std::string path) { std::string ext = path.substr(path.find_last_of(".") + 1); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); return ext; } -sox_encoding_t get_encoding( - const std::string filetype, - const caffe2::TypeMeta dtype) { - if (filetype == "mp3") - return SOX_ENCODING_MP3; - if (filetype == "flac") - return SOX_ENCODING_FLAC; - if (filetype == "ogg" || filetype == "vorbis") - return SOX_ENCODING_VORBIS; - if (filetype == "wav" || filetype == "amb") { - if (dtype == torch::kUInt8) - return SOX_ENCODING_UNSIGNED; - if (dtype == torch::kInt16) - return SOX_ENCODING_SIGN2; - if (dtype == torch::kInt32) - return SOX_ENCODING_SIGN2; - if (dtype == torch::kFloat32) - return SOX_ENCODING_FLOAT; - throw std::runtime_error("Unsupported dtype."); +namespace { + +std::tuple get_save_encoding_for_wav( + const std::string format, + const c10::optional& encoding, + const c10::optional& bits_per_sample) { + if (!encoding.has_value()) { + if (!bits_per_sample.has_value()) + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + auto val = static_cast(bits_per_sample.value()); + if (val == 8) + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + return std::make_tuple<>(SOX_ENCODING_SIGN2, val); } - if (filetype == "sph") - return SOX_ENCODING_SIGN2; - if (filetype == "amr-nb") - return SOX_ENCODING_AMR_NB; - throw std::runtime_error("Unsupported file type: " + filetype); + if (encoding == ENCODING_PCM_SIGNED) { + if (!bits_per_sample.has_value()) + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + auto val = static_cast(bits_per_sample.value()); + if (val == 8) { + TORCH_WARN_ONCE( + "%s does not support 8-bit signed PCM encoding. 
Using 16-bit.", + format); + val = 16; + } + return std::make_tuple<>(SOX_ENCODING_SIGN2, val); + } + if (encoding == ENCODING_PCM_UNSIGNED) { + if (!bits_per_sample.has_value()) + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + auto val = static_cast(bits_per_sample.value()); + if (val != 8) + TORCH_WARN_ONCE( + "%s only supports 8-bit for unsigned PCM encoding. Using 8-bit.", + format); + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + } + if (encoding == ENCODING_PCM_FLOAT) { + auto val = static_cast(bits_per_sample.value_or(32)); + if (val != 32) + TORCH_WARN_ONCE( + "%s only supports 32-bit for floating point PCM encoding. Using 32-bit.", + format); + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + } + if (encoding == ENCODING_ULAW) { + auto val = static_cast(bits_per_sample.value_or(8)); + if (val != 8) + TORCH_WARN_ONCE( + "%s only supports 8-bit for mu-law encoding. Using 8-bit.", format); + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + } + if (encoding == ENCODING_ALAW) { + auto val = static_cast(bits_per_sample.value_or(8)); + if (val != 8) + TORCH_WARN_ONCE( + "%s only supports 8-bit for a-law encoding. Using 8-bit.", format); + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + } + std::ostringstream message; + message << format + << " format does not support encoding: " << encoding.value(); + throw std::runtime_error(message.str()); +} + +std::tuple get_save_encoding( + const std::string& format, + const c10::optional& encoding, + const c10::optional& bits_per_sample) { + if (format == "mp3") { + if (encoding.has_value()) { + TORCH_WARN_ONCE("mp3 does not support `encoding` option. Ignoring."); + } + if (bits_per_sample.has_value()) { + TORCH_WARN_ONCE("mp3 does not `bits_per_sample` option. Ignoring."); + } + return std::make_tuple<>(SOX_ENCODING_MP3, 16); + } + if (format == "ogg" || format == "vorbis") { + if (encoding.has_value()) { + TORCH_WARN_ONCE( + "ogg/vorbis does not support `encoding` option. 
Ignoring."); + } + if (bits_per_sample.has_value()) { + TORCH_WARN_ONCE( + "ogg/vorbis does not `bits_per_sample` option. Ignoring."); + } + return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); + } + if (format == "amr-nb") { + if (encoding.has_value()) { + TORCH_WARN_ONCE("amr-nb does not support `encoding` option. Ignoring."); + } + if (bits_per_sample.has_value()) { + TORCH_WARN_ONCE("amr-nb does not `bits_per_sample` option. Ignoring."); + } + return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); + } + if (format == "wav" || format == "amb") { + return get_save_encoding_for_wav(format, encoding, bits_per_sample); + } + if (format == "flac") { + if (encoding.has_value()) { + TORCH_WARN_ONCE("flac does not support `encoding` option. Ignoring."); + } + unsigned bps = [&]() { + unsigned val = static_cast(bits_per_sample.value_or(24)); + if (val > 24) { + TORCH_WARN_ONCE( + "flac does not support bits_per_sample larger than 24. Using 24."); + val = 24; + } + return val; + }(); + return std::make_tuple<>(SOX_ENCODING_FLAC, bps); + } + if (format == "sph") { + if (!encoding.has_value() || encoding == ENCODING_PCM_SIGNED) { + if (!bits_per_sample.has_value()) + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + auto val = static_cast(bits_per_sample.value()); + return std::make_tuple<>(SOX_ENCODING_SIGN2, val); + } + if (encoding == ENCODING_PCM_UNSIGNED || encoding == ENCODING_PCM_FLOAT) { + TORCH_WARN_ONCE( + "sph does not support unsigned integer PCM or floating point PCM. Using signed interger PCM"); + auto val = static_cast(bits_per_sample.value_or(16)); + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, val); + } + if (encoding == ENCODING_ULAW) { + auto val = static_cast(bits_per_sample.value_or(8)); + if (val != 8) + TORCH_WARN_ONCE( + "sph only supports 8-bit for mu-law encoding. 
Using 8-bit."); + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + } + if (encoding == ENCODING_ALAW) { + auto val = static_cast(bits_per_sample.value_or(8)); + return std::make_tuple<>(SOX_ENCODING_ALAW, val); + } + throw std::runtime_error( + "sph format does not support encoding: " + encoding.value()); + } + throw std::runtime_error("Unsupported format: " + format); } unsigned get_precision( @@ -270,14 +372,13 @@ unsigned get_precision( if (filetype == "sph") return 32; if (filetype == "amr-nb") { - TORCH_INTERNAL_ASSERT( - dtype == torch::kInt16, - "When saving to AMR-NB format, the input tensor must be int16 type."); return 16; } throw std::runtime_error("Unsupported file type: " + filetype); } +} // namespace + sox_signalinfo_t get_signalinfo( const torch::Tensor* waveform, const int64_t sample_rate, @@ -325,12 +426,14 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { } sox_encodinginfo_t get_encodinginfo_for_save( - const std::string filetype, - const caffe2::TypeMeta dtype, - c10::optional& compression) { + const std::string& format, + const c10::optional& compression, + const c10::optional& encoding, + const c10::optional& bits_per_sample) { + auto enc = get_save_encoding(format, encoding, bits_per_sample); return sox_encodinginfo_t{ - /*encoding=*/get_encoding(filetype, dtype), - /*bits_per_sample=*/get_precision(filetype, dtype), + /*encoding=*/std::get<0>(enc), + /*bits_per_sample=*/std::get<1>(enc), /*compression=*/compression.value_or(HUGE_VAL), /*reverse_bytes=*/sox_option_default, /*reverse_nibbles=*/sox_option_default, diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index ea2a6a2953..b8423c7625 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -34,6 +34,19 @@ std::vector list_write_formats(); // Utilities for sox_io / sox_effects implementations //////////////////////////////////////////////////////////////////////////////// +const std::string ENCODING_UNKNOWN = 
"UNKNOWN"; +const std::string ENCODING_PCM_SIGNED = "PCM_S"; +const std::string ENCODING_PCM_UNSIGNED = "PCM_U"; +const std::string ENCODING_PCM_FLOAT = "PCM_F"; +const std::string ENCODING_FLAC = "FLAC"; +const std::string ENCODING_ULAW = "ULAW"; +const std::string ENCODING_ALAW = "ALAW"; +const std::string ENCODING_MP3 = "MP3"; +const std::string ENCODING_VORBIS = "VORBIS"; +const std::string ENCODING_AMR_WB = "AMR_WB"; +const std::string ENCODING_AMR_NB = "AMR_NB"; +const std::string ENCODING_OPUS = "OPUS"; + const std::unordered_set UNSUPPORTED_EFFECTS = {"input", "output", "spectrogram", "noiseprof", "noisered", "splice"}; @@ -93,11 +106,6 @@ torch::Tensor convert_to_tensor( const bool normalize, const bool channels_first); -/// -/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox -/// conversion. -torch::Tensor unnormalize_wav(const torch::Tensor); - /// Extract extension from file path const std::string get_filetype(const std::string path); @@ -113,9 +121,10 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); /// Get sox_encodinginfo_t for saving to file/file object sox_encodinginfo_t get_encodinginfo_for_save( - const std::string filetype, - const caffe2::TypeMeta dtype, - c10::optional& compression); + const std::string& format, + const c10::optional& compression, + const c10::optional& encoding, + const c10::optional& bits_per_sample); #ifdef TORCH_API_INCLUDE_EXTENSION_H From 089d7bc628d795e019d7e0904274b3a251bfc9fc Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 11 Feb 2021 21:40:15 +0000 Subject: [PATCH 02/12] Updater docstrings --- .../backend/sox_io/save_test.py | 7 +- torchaudio/backend/sox_io_backend.py | 145 +++++++++++------- 2 files changed, 95 insertions(+), 57 deletions(-) diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index dc06d48ce7..2e826116e1 100644 --- 
a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -202,6 +202,7 @@ def test_save_mp3(self, test_mode, bit_rate): @nested_params( ["path", "fileobj", "bytesio"], + [8, 16, 24], [ None, 0, @@ -215,9 +216,10 @@ def test_save_mp3(self, test_mode, bit_rate): 8, ], ) - def test_save_flac(self, test_mode, compression_level): + def test_save_flac(self, test_mode, bits_per_sample, compression_level): self.assert_save_consistency( - "flac", compression=compression_level, test_mode=test_mode) + "flac", compression=compression_level, + bits_per_sample=bits_per_sample, test_mode=test_mode) @nested_params( ["path", "fileobj", "bytesio"], @@ -277,6 +279,7 @@ def test_save_amb(self, test_mode, enc_params): @nested_params( ["path", "fileobj", "bytesio"], [ + None, 0, 1, 2, diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 67f9f3e0ed..72b2f4f5dd 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -164,25 +164,48 @@ def save( ): """Save audio data to file. - Note: - Supported formats are; - - * WAV, AMB - - * 32-bit floating-point - * 32-bit signed integer - * 16-bit signed integer - * 8-bit unsigned integer - - * MP3 - * FLAC - * OGG/VORBIS - * SPHERE - * AMR-NB + Supported formats/encodings/bit depths/compression are; + + ``"wav"``, ``"amb"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM (default) + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + ``"mp3"`` + Fixed bit rate (such as 128kHz) and variable bit rate compression. + Default: VBR with high quality. + + ``"flac"`` + - 8-bit + - 16-bit + - 24-bit (default) + + ``"ogg"``, ``"vorbis"`` + - Different quality level. Default: approx. 
112kbps + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM (default) + - 24-bit signed integer PCM + - 32-bit signed integer PCM + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + + ``"amr-nb"`` + Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s - To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not - handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` - and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc. + Note: + To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, + ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has + to be linked to ``libsox`` and corresponding codec libraries such as ``libmad`` + or ``libmp3lame`` etc. Args: filepath (str or pathlib.Path): Path to save file. @@ -195,12 +218,16 @@ def save( compression (Optional[float]): Used for formats other than WAV. This corresponds to ``-C`` option of ``sox`` command. - * | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or - | VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. - * | ``FLAC``: compression level. Whole number from ``0`` to ``8``. - | ``8`` is default and highest compression. - * | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression - | and lowest quality. Default: ``3``. + ``"mp3"`` + Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or + VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. + + ``"flac"`` + Whole number from ``0`` to ``8``. ``8`` is default and highest compression. + + ``"ogg"``, ``"vorbis"`` + Number from ``-1`` to ``10``; ``-1`` is the highest compression + and lowest quality. Default: ``3``. See the detail at http://sox.sourceforge.net/soxformat.html. 
format (str, optional): @@ -211,47 +238,55 @@ def save( When not provided, the value of file extension is used. Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, ``"amb"``, ``"flac"`` and ``"sph"``. - encoding (str, optional): - Changes the encoding for the supported formats, such as ``"wav"``, ``"sph"``. - and ``"amb"``. - Valid values are ``"PCM_S"`` (signed integer Linear PCM), ``"PCM_U"`` - (unsigned integer Linear PCM), ``"PCM_F"`` (floating point PCM), - ``"ULAW"`` (mu-law) and ``"ALAW"`` (a-law). - Different formats support different set of encodings. Providing a value that is not - supported by the format will not cause an error, but will fallback to its default value. + encoding (str, optional): Changes the encoding for the supported formats. + This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"`` + and ``"sph"``. Valid values are; - If not provided, the default values are picked based on ``format`` and - ``bits_per_sample``; + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) - For ``"wav"`` and ``"amb"`` formats, the default value is; - - ``"PCM_U"`` if ``bits_per_sample=8`` - - ``"PCM_S"`` otherwise - For ``"sph"`` format, the default value is ``"PCM_S"``. + Default values + If not provided, the default value is picked based on ``format`` and ``bits_per_sample``. + + ``"wav"``, ``"amb"`` + - ``"PCM_U"`` if ``bits_per_sample=8`` + - ``"PCM_S"`` otherwise + + ``"sph"`` format; + - the default value is ``"PCM_S"`` - bits_per_sample (int, optional): - Change the bit depth for the supported formats, such as ``"wav"``, ``"flac"``, - ``"sph"``, and ``"amb"``. - Valid values are ``8``, ``16``, ``32`` and ``64``. Different formats support different set of encodings. Providing a value that is not supported by the format will not cause an error, but will fallback to its default value. 
- If not provided, the default values are picked based on ``format`` and ``"encoding"``; + bits_per_sample (int, optional): Changes the bit depth for the supported formats. + When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the + bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``. + + Default Value; + If not provided, the default values are picked based on ``format`` and ``"encoding"``; - For ``"wav"`` format, the default value is; - - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` - - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. - - ``32`` if ``encoding`` is ``"PCM_F"`` + ``"wav"`` format; + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. + - ``32`` if ``encoding`` is ``"PCM_F"`` - For ``"flac"`` format, the default value is ``24``. + ``"flac"`` format; + - the default value is ``24`` - For ``"sph"`` format, the default value is; - - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided. - - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` + ``"sph"`` format; + - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided. + - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` - For ``"amb"`` format, the default value is; - - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` - - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. - - ``32`` if ``encoding`` is ``"PCM_F"`` + ``"amb"`` format; + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. + - ``32`` if ``encoding`` is ``"PCM_F"`` + + Different formats support different set of encodings. Providing a value that is not + supported by the format will not cause an error, but will fallback to its default value. 
""" if not torch.jit.is_scripting(): if hasattr(filepath, 'write'): From 72c767c90a0f77437cf3e90347503142fe38b484 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 11 Feb 2021 23:37:50 +0000 Subject: [PATCH 03/12] Add type --- torchaudio/csrc/CMakeLists.txt | 1 + torchaudio/csrc/sox/types.cpp | 57 ++++++++++++++++++++++++++++++++++ torchaudio/csrc/sox/types.h | 31 ++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 torchaudio/csrc/sox/types.cpp create mode 100644 torchaudio/csrc/sox/types.h diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index 409d42f57f..dff7a8368a 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -9,6 +9,7 @@ set( sox/utils.cpp sox/effects.cpp sox/effects_chain.cpp + sox/types.cpp ) if(BUILD_TRANSDUCER) diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp new file mode 100644 index 0000000000..82b2c2cdef --- /dev/null +++ b/torchaudio/csrc/sox/types.cpp @@ -0,0 +1,57 @@ +#include + +namespace torchaudio { +namespace sox { + +std::string to_string(Encoding v) { + switch(v) { + case Encoding::UNKNOWN: + return "UNKNOWN"; + case Encoding::PCM_S: + return "PCM_S"; + case Encoding::PCM_U: + return "PCM_U"; + case Encoding::PCM_F: + return "PCM_F"; + case Encoding::FLAC: + return "FLAC"; + case Encoding::ULAW: + return "ULAW"; + case Encoding::ALAW: + return "ALAW"; + case Encoding::MP3: + return "MP3"; + case Encoding::VORBIS: + return "VORBIS"; + case Encoding::AMR_WB: + return "AMR_WB"; + case Encoding::AMR_NB: + return "AMR_NB"; + case Encoding::OPUS: + return "OPUS"; + default: + throw std::runtime_error("Internal Error: unexpected encoding."); + } +} + +Encoding from_string(const c10::optional& encoding) { + if (!encoding.has_value()) + return Encoding::NOT_PROVIDED; + std::string v = encoding.get(); + if (v == "PCM_S") + return Encoding::PCM_S; + if (v == "PCM_U") + return Encoding::PCM_U; + if (v == 
"PCM_F") + return Encoding::PCM_F; + if (v == "ULAW") + return Encoding::ULAW; + if (v == "ALAW") + return Encoding::ALAW; + std::ostringstream stream; + stream << "Internal Error: unexpected encoding value: " << v; + throw std::runtime_error(stream.str()); +} + +} // namespace sox +} // namespace torchaudio diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h new file mode 100644 index 0000000000..1d8dfb63ec --- /dev/null +++ b/torchaudio/csrc/sox/types.h @@ -0,0 +1,31 @@ +#ifndef TORCHAUDIO_SOX_TYPES_H +#define TORCHAUDIO_SOX_TYPES_H + +#include + +namespace torchaudio { +namespace sox { + +enum class Encoding { + NOT_PROVIDED, + UNKNOWN, + PCM_S, + PCM_U, + PCM_F, + FLAC, + ULAW, + ALAW, + MP3, + VORBIS, + AMR_WB, + AMR_NB, + OPUS, +}; + +std::string to_string(Encoding v); +Encoding from_string(const c10::optional& encoding); + +} // namespace sox +} // namespace torchaudio + +#endif From a2b2c41a32e2f6c0c5a48ec8094b80106e58f9da Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 02:25:23 +0000 Subject: [PATCH 04/12] Update --- .../backend/sox_io/save_test.py | 2 +- torchaudio/csrc/sox/types.cpp | 61 ++++- torchaudio/csrc/sox/types.h | 36 ++- torchaudio/csrc/sox/utils.cpp | 240 +++++++++--------- 4 files changed, 204 insertions(+), 135 deletions(-) diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index 2e826116e1..fce5cda89f 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -320,7 +320,7 @@ def test_save_large(self, format, encoding=None, bits_per_sample=None): def test_save_multi_channels(self, num_channels): """`sox_io_backend.save` can save audio with many channels""" self.assert_save_consistency( - "wav", encoding="PCM_U", bits_per_sample=16, + "wav", encoding="PCM_S", bits_per_sample=16, num_channels=num_channels) diff --git 
a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index 82b2c2cdef..434d6a79b2 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -1,17 +1,39 @@ #include namespace torchaudio { -namespace sox { +namespace sox_utils { +Format from_string(const std::string& format) { + if (format == "wav") + return Format::WAV; + if (format == "mp3") + return Format::MP3; + if (format == "flac") + return Format::FLAC; + if (format == "ogg" || format == "vorbis") + return Format::VORBIS; + if (format == "amr-nb") + return Format::AMR_NB; + if (format == "amr-wb") + return Format::AMR_WB; + if (format == "amb") + return Format::AMB; + if (format == "sph") + return Format::SPHERE; + std::ostringstream stream; + stream << "Internal Error: unexpected format value: " << format; + throw std::runtime_error(stream.str()); +} + std::string to_string(Encoding v) { switch(v) { case Encoding::UNKNOWN: return "UNKNOWN"; - case Encoding::PCM_S: + case Encoding::PCM_SIGNED: return "PCM_S"; - case Encoding::PCM_U: + case Encoding::PCM_UNSIGNED: return "PCM_U"; - case Encoding::PCM_F: + case Encoding::PCM_FLOAT: return "PCM_F"; case Encoding::FLAC: return "FLAC"; @@ -34,16 +56,16 @@ std::string to_string(Encoding v) { } } -Encoding from_string(const c10::optional& encoding) { +Encoding from_option(const c10::optional& encoding) { if (!encoding.has_value()) return Encoding::NOT_PROVIDED; - std::string v = encoding.get(); + std::string v = encoding.value(); if (v == "PCM_S") - return Encoding::PCM_S; + return Encoding::PCM_SIGNED; if (v == "PCM_U") - return Encoding::PCM_U; + return Encoding::PCM_UNSIGNED; if (v == "PCM_F") - return Encoding::PCM_F; + return Encoding::PCM_FLOAT; if (v == "ULAW") return Encoding::ULAW; if (v == "ALAW") @@ -53,5 +75,24 @@ Encoding from_string(const c10::optional& encoding) { throw std::runtime_error(stream.str()); } -} // namespace sox +BitDepth from_option(const c10::optional& bit_depth) { + if (!bit_depth.has_value()) + 
return BitDepth::NOT_PROVIDED; + int64_t v = bit_depth.value(); + if (v == 8) + return BitDepth::B8; + if (v == 16) + return BitDepth::B16; + if (v == 24) + return BitDepth::B24; + if (v == 32) + return BitDepth::B32; + if (v == 64) + return BitDepth::B64; + std::ostringstream stream; + stream << "Internal Error: unexpected bit depth value: " << v; + throw std::runtime_error(stream.str()); +} + +} // namespace sox_utils } // namespace torchaudio diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h index 1d8dfb63ec..29f3895cee 100644 --- a/torchaudio/csrc/sox/types.h +++ b/torchaudio/csrc/sox/types.h @@ -4,14 +4,27 @@ #include namespace torchaudio { -namespace sox { +namespace sox_utils { +enum class Format { + WAV, + MP3, + FLAC, + VORBIS, + AMR_NB, + AMR_WB, + AMB, + SPHERE, +}; + +Format from_string(const std::string& format); + enum class Encoding { NOT_PROVIDED, UNKNOWN, - PCM_S, - PCM_U, - PCM_F, + PCM_SIGNED, + PCM_UNSIGNED, + PCM_FLOAT, FLAC, ULAW, ALAW, @@ -23,9 +36,20 @@ enum class Encoding { }; std::string to_string(Encoding v); -Encoding from_string(const c10::optional& encoding); +Encoding from_option(const c10::optional& encoding); + +enum class BitDepth : unsigned { + NOT_PROVIDED = 0, + B8 = 8, + B16 = 16, + B24 = 24, + B32 = 32, + B64 = 64, +}; -} // namespace sox +BitDepth from_option(const c10::optional& bit_depth); + +} // namespace sox_utils } // namespace torchaudio #endif diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 17a585b777..ef3d414249 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -1,5 +1,6 @@ #include #include +#include #include namespace torchaudio { @@ -208,145 +209,148 @@ namespace { std::tuple get_save_encoding_for_wav( const std::string format, - const c10::optional& encoding, - const c10::optional& bits_per_sample) { - if (!encoding.has_value()) { - if (!bits_per_sample.has_value()) + const Encoding& encoding, + const BitDepth& bits_per_sample) 
{ + + switch(encoding) { + case Encoding::NOT_PROVIDED: + switch(bits_per_sample) { + case BitDepth::NOT_PROVIDED: return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); - auto val = static_cast(bits_per_sample.value()); - if (val == 8) + case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - return std::make_tuple<>(SOX_ENCODING_SIGN2, val); - } - if (encoding == ENCODING_PCM_SIGNED) { - if (!bits_per_sample.has_value()) + default: + return std::make_tuple<>(SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_SIGNED: + switch(bits_per_sample) { + case BitDepth::NOT_PROVIDED: return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); - auto val = static_cast(bits_per_sample.value()); - if (val == 8) { - TORCH_WARN_ONCE( - "%s does not support 8-bit signed PCM encoding. Using 16-bit.", - format); - val = 16; + case BitDepth::B8: + throw std::runtime_error(format + " does not support 8-bit signed PCM encoding."); + default: + return std::make_tuple<>(SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); } - return std::make_tuple<>(SOX_ENCODING_SIGN2, val); - } - if (encoding == ENCODING_PCM_UNSIGNED) { - if (!bits_per_sample.has_value()) + case Encoding::PCM_UNSIGNED: + switch(bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - auto val = static_cast(bits_per_sample.value()); - if (val != 8) - TORCH_WARN_ONCE( - "%s only supports 8-bit for unsigned PCM encoding. Using 8-bit.", - format); - return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - } - if (encoding == ENCODING_PCM_FLOAT) { - auto val = static_cast(bits_per_sample.value_or(32)); - if (val != 32) - TORCH_WARN_ONCE( - "%s only supports 32-bit for floating point PCM encoding. 
Using 32-bit.", - format); - return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); - } - if (encoding == ENCODING_ULAW) { - auto val = static_cast(bits_per_sample.value_or(8)); - if (val != 8) - TORCH_WARN_ONCE( - "%s only supports 8-bit for mu-law encoding. Using 8-bit.", format); - return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for unsigned PCM encoding."); + } + case Encoding::PCM_FLOAT: + switch(bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B32: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + case BitDepth::B64: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); + default: + throw std::runtime_error( + format + " only supports 32-bit or 64-bit for floating-point PCM encoding."); + } + case Encoding::ULAW: + switch(bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error(format + " only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch(bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + throw std::runtime_error(format + " only supports 8-bit for a-law encoding."); } - if (encoding == ENCODING_ALAW) { - auto val = static_cast(bits_per_sample.value_or(8)); - if (val != 8) - TORCH_WARN_ONCE( - "%s only supports 8-bit for a-law encoding. 
Using 8-bit.", format); - return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + throw std::runtime_error( + format + " does not support encoding: " + to_string(encoding)); } - std::ostringstream message; - message << format - << " format does not support encoding: " << encoding.value(); - throw std::runtime_error(message.str()); } std::tuple get_save_encoding( const std::string& format, const c10::optional& encoding, const c10::optional& bits_per_sample) { - if (format == "mp3") { - if (encoding.has_value()) { - TORCH_WARN_ONCE("mp3 does not support `encoding` option. Ignoring."); - } - if (bits_per_sample.has_value()) { - TORCH_WARN_ONCE("mp3 does not `bits_per_sample` option. Ignoring."); - } + + const Format fmt = from_string(format); + const Encoding enc = from_option(encoding); + const BitDepth bps = from_option(bits_per_sample); + + switch(fmt) { + case Format::WAV: + case Format::AMB: + return get_save_encoding_for_wav(format, enc, bps); + case Format::MP3: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("mp3 does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error("mp3 does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_MP3, 16); - } - if (format == "ogg" || format == "vorbis") { - if (encoding.has_value()) { - TORCH_WARN_ONCE( - "ogg/vorbis does not support `encoding` option. Ignoring."); - } - if (bits_per_sample.has_value()) { - TORCH_WARN_ONCE( - "ogg/vorbis does not `bits_per_sample` option. Ignoring."); - } + case Format::VORBIS: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("vorbis does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error("vorbis does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); - } - if (format == "amr-nb") { - if (encoding.has_value()) { - TORCH_WARN_ONCE("amr-nb does not support `encoding` option. 
Ignoring."); - } - if (bits_per_sample.has_value()) { - TORCH_WARN_ONCE("amr-nb does not `bits_per_sample` option. Ignoring."); - } + case Format::AMR_NB: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("amr-nb does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error("amr-nb does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); - } - if (format == "wav" || format == "amb") { - return get_save_encoding_for_wav(format, encoding, bits_per_sample); - } - if (format == "flac") { - if (encoding.has_value()) { - TORCH_WARN_ONCE("flac does not support `encoding` option. Ignoring."); + case Format::FLAC: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("flac does not support `encoding` option."); + switch(bps) { + case BitDepth::B32: + case BitDepth::B64: + throw std::runtime_error("flac does not support `bits_per_sample` larger than 24."); + default: + return std::make_tuple<>(SOX_ENCODING_FLAC, static_cast(bps)); } - unsigned bps = [&]() { - unsigned val = static_cast(bits_per_sample.value_or(24)); - if (val > 24) { - TORCH_WARN_ONCE( - "flac does not support bits_per_sample larger than 24. Using 24."); - val = 24; - } - return val; - }(); - return std::make_tuple<>(SOX_ENCODING_FLAC, bps); - } - if (format == "sph") { - if (!encoding.has_value() || encoding == ENCODING_PCM_SIGNED) { - if (!bits_per_sample.has_value()) + case Format::SPHERE: + switch(enc) { + case Encoding::NOT_PROVIDED: + case Encoding::PCM_SIGNED: + switch(bps) { + case BitDepth::NOT_PROVIDED: return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); - auto val = static_cast(bits_per_sample.value()); - return std::make_tuple<>(SOX_ENCODING_SIGN2, val); - } - if (encoding == ENCODING_PCM_UNSIGNED || encoding == ENCODING_PCM_FLOAT) { - TORCH_WARN_ONCE( - "sph does not support unsigned integer PCM or floating point PCM. 
Using signed interger PCM"); - auto val = static_cast(bits_per_sample.value_or(16)); - return std::make_tuple<>(SOX_ENCODING_UNSIGNED, val); - } - if (encoding == ENCODING_ULAW) { - auto val = static_cast(bits_per_sample.value_or(8)); - if (val != 8) - TORCH_WARN_ONCE( - "sph only supports 8-bit for mu-law encoding. Using 8-bit."); - return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + return std::make_tuple<>(SOX_ENCODING_SIGN2, static_cast(bps)); + } + case Encoding::PCM_UNSIGNED: + throw std::runtime_error("sph does not support unsigned integer PCM."); + case Encoding::PCM_FLOAT: + throw std::runtime_error("sph does not support floating point PCM."); + case Encoding::ULAW: + switch(bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error("sph only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch(bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + return std::make_tuple<>(SOX_ENCODING_ALAW, static_cast(bps)); + } + default: { + throw std::runtime_error("sph does not support encoding: " + encoding.value()); } - if (encoding == ENCODING_ALAW) { - auto val = static_cast(bits_per_sample.value_or(8)); - return std::make_tuple<>(SOX_ENCODING_ALAW, val); } - throw std::runtime_error( - "sph format does not support encoding: " + encoding.value()); + default: + throw std::runtime_error("Unsupported format: " + format); } - throw std::runtime_error("Unsupported format: " + format); } unsigned get_precision( From 568a815398287a0f00de3ccd0592ece42b388986 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 02:27:23 +0000 Subject: [PATCH 05/12] Fix style --- torchaudio/csrc/sox/types.cpp | 58 ++++---- torchaudio/csrc/sox/types.h | 4 +- torchaudio/csrc/sox/utils.cpp | 261 ++++++++++++++++++---------------- 3 files changed, 168 
insertions(+), 155 deletions(-) diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index 434d6a79b2..1e9c803e3b 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -24,35 +24,35 @@ Format from_string(const std::string& format) { stream << "Internal Error: unexpected format value: " << format; throw std::runtime_error(stream.str()); } - + std::string to_string(Encoding v) { - switch(v) { - case Encoding::UNKNOWN: - return "UNKNOWN"; - case Encoding::PCM_SIGNED: - return "PCM_S"; - case Encoding::PCM_UNSIGNED: - return "PCM_U"; - case Encoding::PCM_FLOAT: - return "PCM_F"; - case Encoding::FLAC: - return "FLAC"; - case Encoding::ULAW: - return "ULAW"; - case Encoding::ALAW: - return "ALAW"; - case Encoding::MP3: - return "MP3"; - case Encoding::VORBIS: - return "VORBIS"; - case Encoding::AMR_WB: - return "AMR_WB"; - case Encoding::AMR_NB: - return "AMR_NB"; - case Encoding::OPUS: - return "OPUS"; - default: - throw std::runtime_error("Internal Error: unexpected encoding."); + switch (v) { + case Encoding::UNKNOWN: + return "UNKNOWN"; + case Encoding::PCM_SIGNED: + return "PCM_S"; + case Encoding::PCM_UNSIGNED: + return "PCM_U"; + case Encoding::PCM_FLOAT: + return "PCM_F"; + case Encoding::FLAC: + return "FLAC"; + case Encoding::ULAW: + return "ULAW"; + case Encoding::ALAW: + return "ALAW"; + case Encoding::MP3: + return "MP3"; + case Encoding::VORBIS: + return "VORBIS"; + case Encoding::AMR_WB: + return "AMR_WB"; + case Encoding::AMR_NB: + return "AMR_NB"; + case Encoding::OPUS: + return "OPUS"; + default: + throw std::runtime_error("Internal Error: unexpected encoding."); } } @@ -93,6 +93,6 @@ BitDepth from_option(const c10::optional& bit_depth) { stream << "Internal Error: unexpected bit depth value: " << v; throw std::runtime_error(stream.str()); } - + } // namespace sox_utils } // namespace torchaudio diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h index 29f3895cee..6382e83b69 100644 
--- a/torchaudio/csrc/sox/types.h +++ b/torchaudio/csrc/sox/types.h @@ -18,7 +18,7 @@ enum class Format { }; Format from_string(const std::string& format); - + enum class Encoding { NOT_PROVIDED, UNKNOWN, @@ -48,7 +48,7 @@ enum class BitDepth : unsigned { }; BitDepth from_option(const c10::optional& bit_depth); - + } // namespace sox_utils } // namespace torchaudio diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index ef3d414249..132723a35c 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -211,65 +211,70 @@ std::tuple get_save_encoding_for_wav( const std::string format, const Encoding& encoding, const BitDepth& bits_per_sample) { - - switch(encoding) { - case Encoding::NOT_PROVIDED: - switch(bits_per_sample) { - case BitDepth::NOT_PROVIDED: - return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); - case BitDepth::B8: - return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - default: - return std::make_tuple<>(SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); - } - case Encoding::PCM_SIGNED: - switch(bits_per_sample) { - case BitDepth::NOT_PROVIDED: - return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); - case BitDepth::B8: - throw std::runtime_error(format + " does not support 8-bit signed PCM encoding."); - default: - return std::make_tuple<>(SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); - } - case Encoding::PCM_UNSIGNED: - switch(bits_per_sample) { - case BitDepth::NOT_PROVIDED: - case BitDepth::B8: - return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - default: - throw std::runtime_error( - format + " only supports 8-bit for unsigned PCM encoding."); - } - case Encoding::PCM_FLOAT: - switch(bits_per_sample) { - case BitDepth::NOT_PROVIDED: - case BitDepth::B32: - return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); - case BitDepth::B64: - return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); + switch (encoding) { + case Encoding::NOT_PROVIDED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + return 
std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_SIGNED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + case BitDepth::B8: + throw std::runtime_error( + format + " does not support 8-bit signed PCM encoding."); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_UNSIGNED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for unsigned PCM encoding."); + } + case Encoding::PCM_FLOAT: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B32: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + case BitDepth::B64: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); + default: + throw std::runtime_error( + format + + " only supports 32-bit or 64-bit for floating-point PCM encoding."); + } + case Encoding::ULAW: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for a-law encoding."); + } default: throw std::runtime_error( - format + " only supports 32-bit or 64-bit for floating-point PCM encoding."); - } - case Encoding::ULAW: - switch(bits_per_sample) { - case BitDepth::NOT_PROVIDED: - case BitDepth::B8: - return std::make_tuple<>(SOX_ENCODING_ULAW, 8); - default: - throw 
std::runtime_error(format + " only supports 8-bit for mu-law encoding."); - } - case Encoding::ALAW: - switch(bits_per_sample) { - case BitDepth::NOT_PROVIDED: - case BitDepth::B8: - return std::make_tuple<>(SOX_ENCODING_ALAW, 8); - default: - throw std::runtime_error(format + " only supports 8-bit for a-law encoding."); - } - default: - throw std::runtime_error( - format + " does not support encoding: " + to_string(encoding)); + format + " does not support encoding: " + to_string(encoding)); } } @@ -277,79 +282,87 @@ std::tuple get_save_encoding( const std::string& format, const c10::optional& encoding, const c10::optional& bits_per_sample) { - const Format fmt = from_string(format); const Encoding enc = from_option(encoding); const BitDepth bps = from_option(bits_per_sample); - switch(fmt) { - case Format::WAV: - case Format::AMB: - return get_save_encoding_for_wav(format, enc, bps); - case Format::MP3: - if (enc != Encoding::NOT_PROVIDED) - throw std::runtime_error("mp3 does not support `encoding` option."); - if (bps != BitDepth::NOT_PROVIDED) - throw std::runtime_error("mp3 does not support `bits_per_sample` option."); - return std::make_tuple<>(SOX_ENCODING_MP3, 16); - case Format::VORBIS: - if (enc != Encoding::NOT_PROVIDED) - throw std::runtime_error("vorbis does not support `encoding` option."); - if (bps != BitDepth::NOT_PROVIDED) - throw std::runtime_error("vorbis does not support `bits_per_sample` option."); - return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); - case Format::AMR_NB: - if (enc != Encoding::NOT_PROVIDED) - throw std::runtime_error("amr-nb does not support `encoding` option."); - if (bps != BitDepth::NOT_PROVIDED) - throw std::runtime_error("amr-nb does not support `bits_per_sample` option."); - return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); - case Format::FLAC: - if (enc != Encoding::NOT_PROVIDED) - throw std::runtime_error("flac does not support `encoding` option."); - switch(bps) { - case BitDepth::B32: - case BitDepth::B64: - throw 
std::runtime_error("flac does not support `bits_per_sample` larger than 24."); - default: - return std::make_tuple<>(SOX_ENCODING_FLAC, static_cast(bps)); - } - case Format::SPHERE: - switch(enc) { - case Encoding::NOT_PROVIDED: - case Encoding::PCM_SIGNED: - switch(bps) { - case BitDepth::NOT_PROVIDED: - return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); - default: - return std::make_tuple<>(SOX_ENCODING_SIGN2, static_cast(bps)); + switch (fmt) { + case Format::WAV: + case Format::AMB: + return get_save_encoding_for_wav(format, enc, bps); + case Format::MP3: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("mp3 does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "mp3 does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_MP3, 16); + case Format::VORBIS: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("vorbis does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "vorbis does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); + case Format::AMR_NB: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("amr-nb does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "amr-nb does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); + case Format::FLAC: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("flac does not support `encoding` option."); + switch (bps) { + case BitDepth::B32: + case BitDepth::B64: + throw std::runtime_error( + "flac does not support `bits_per_sample` larger than 24."); + default: + return std::make_tuple<>( + SOX_ENCODING_FLAC, static_cast(bps)); } - case Encoding::PCM_UNSIGNED: - throw std::runtime_error("sph does not support unsigned integer PCM."); - case Encoding::PCM_FLOAT: - throw std::runtime_error("sph does not 
support floating point PCM."); - case Encoding::ULAW: - switch(bps) { - case BitDepth::NOT_PROVIDED: - case BitDepth::B8: - return std::make_tuple<>(SOX_ENCODING_ULAW, 8); - default: - throw std::runtime_error("sph only supports 8-bit for mu-law encoding."); - } - case Encoding::ALAW: - switch(bps) { - case BitDepth::NOT_PROVIDED: - case BitDepth::B8: - return std::make_tuple<>(SOX_ENCODING_ALAW, 8); - default: - return std::make_tuple<>(SOX_ENCODING_ALAW, static_cast(bps)); + case Format::SPHERE: + switch (enc) { + case Encoding::NOT_PROVIDED: + case Encoding::PCM_SIGNED: + switch (bps) { + case BitDepth::NOT_PROVIDED: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bps)); + } + case Encoding::PCM_UNSIGNED: + throw std::runtime_error( + "sph does not support unsigned integer PCM."); + case Encoding::PCM_FLOAT: + throw std::runtime_error("sph does not support floating point PCM."); + case Encoding::ULAW: + switch (bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + "sph only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch (bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + return std::make_tuple<>( + SOX_ENCODING_ALAW, static_cast(bps)); + } + default: + throw std::runtime_error( + "sph does not support encoding: " + encoding.value()); } - default: { - throw std::runtime_error("sph does not support encoding: " + encoding.value()); - } - } - default: - throw std::runtime_error("Unsupported format: " + format); + default: + throw std::runtime_error("Unsupported format: " + format); } } From 4e9999568c82cc44be84a87f9a5ae27c2f146b75 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 03:23:21 +0000 Subject: [PATCH 06/12] Default to dtype --- 
.../backend/sox_io/save_test.py | 18 +++++++++++++++- torchaudio/backend/sox_io_backend.py | 4 ++-- torchaudio/csrc/sox/io.cpp | 4 ++-- torchaudio/csrc/sox/utils.cpp | 21 ++++++++++++++----- torchaudio/csrc/sox/utils.h | 1 + 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index fce5cda89f..dca9c3ad9c 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -44,6 +44,7 @@ def assert_save_consistency( sample_rate: float = 8000, num_channels: int = 2, num_frames: float = 3 * 8000, + src_dtype: str = 'int32', test_mode: str = "path", ): """`save` function produces file that is comparable with `sox` command @@ -93,7 +94,7 @@ def assert_save_consistency( ref_path = self.get_temp_path('3.2.ref.wav') # 1. Generate original wav - data = get_wav_data('int32', num_channels, normalize=False, num_frames=num_frames) + data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames) save_wav(src_path, data, sample_rate) # 2.1. 
Convert the original wav to target format with torchaudio @@ -165,6 +166,7 @@ class SaveTest(SaveTestBase): ('PCM_S', 16), ('PCM_S', 32), ('PCM_F', 32), + ('PCM_F', 64), ('ULAW', 8), ('ALAW', 8), ], @@ -174,6 +176,20 @@ def test_save_wav(self, test_mode, enc_params): self.assert_save_consistency( "wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('float32', ), + ('int32', ), + ('int16', ), + ('uint8', ), + ], + ) + def test_save_wav_dtype(self, test_mode, params): + dtype, = params + self.assert_save_consistency( + "wav", src_dtype=dtype, test_mode=test_mode) + @nested_params( ["path", "fileobj", "bytesio"], [ diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 72b2f4f5dd..59201c5c4f 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -170,7 +170,7 @@ def save( - 32-bit floating-point PCM - 32-bit signed integer PCM - 24-bit signed integer PCM - - 16-bit signed integer PCM (default) + - 16-bit signed integer PCM - 8-bit unsigned integer PCM - 8-bit mu-law - 8-bit a-law @@ -189,7 +189,7 @@ def save( ``"sph"`` - 8-bit signed integer PCM - - 16-bit signed integer PCM (default) + - 16-bit signed integer PCM - 24-bit signed integer PCM - 32-bit signed integer PCM - 8-bit mu-law diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index c311155d2b..ae62411c5a 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -135,7 +135,7 @@ void save_audio_file( const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); const auto encoding_info = get_encodinginfo_for_save( - filetype, compression, encoding, bits_per_sample); + filetype, tensor.dtype(), compression, encoding, bits_per_sample); SoxFormat sf(sox_open_write( path.c_str(), @@ -271,7 +271,7 @@ void save_audio_fileobj( const auto signal_info = get_signalinfo(&tensor, sample_rate, 
filetype, channels_first); const auto encoding_info = get_encodinginfo_for_save( - filetype, compression, encoding, bits_per_sample); + filetype, tensor.dtype(), compression, encoding, bits_per_sample); AutoReleaseBuffer buffer; diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 132723a35c..a199a61ebb 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -209,13 +209,22 @@ namespace { std::tuple get_save_encoding_for_wav( const std::string format, + const caffe2::TypeMeta dtype, const Encoding& encoding, const BitDepth& bits_per_sample) { switch (encoding) { case Encoding::NOT_PROVIDED: switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: - return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + if (dtype == torch::kFloat32) + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + if (dtype == torch::kInt32) + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + if (dtype == torch::kInt16) + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + if (dtype == torch::kUInt8) + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + throw std::runtime_error("Internal Error: Unexpected dtype."); case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); default: @@ -225,7 +234,7 @@ std::tuple get_save_encoding_for_wav( case Encoding::PCM_SIGNED: switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: - return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); case BitDepth::B8: throw std::runtime_error( format + " does not support 8-bit signed PCM encoding."); @@ -280,6 +289,7 @@ std::tuple get_save_encoding_for_wav( std::tuple get_save_encoding( const std::string& format, + const caffe2::TypeMeta dtype, const c10::optional& encoding, const c10::optional& bits_per_sample) { const Format fmt = from_string(format); @@ -289,7 +299,7 @@ std::tuple get_save_encoding( switch (fmt) { case Format::WAV: case Format::AMB: - return get_save_encoding_for_wav(format, enc, bps); + 
return get_save_encoding_for_wav(format, dtype, enc, bps); case Format::MP3: if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("mp3 does not support `encoding` option."); @@ -329,7 +339,7 @@ std::tuple get_save_encoding( case Encoding::PCM_SIGNED: switch (bps) { case BitDepth::NOT_PROVIDED: - return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); default: return std::make_tuple<>( SOX_ENCODING_SIGN2, static_cast(bps)); @@ -444,10 +454,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { sox_encodinginfo_t get_encodinginfo_for_save( const std::string& format, + const caffe2::TypeMeta dtype, const c10::optional& compression, const c10::optional& encoding, const c10::optional& bits_per_sample) { - auto enc = get_save_encoding(format, encoding, bits_per_sample); + auto enc = get_save_encoding(format, dtype, encoding, bits_per_sample); return sox_encodinginfo_t{ /*encoding=*/std::get<0>(enc), /*bits_per_sample=*/std::get<1>(enc), diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index b8423c7625..012ba281ea 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -122,6 +122,7 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); /// Get sox_encodinginfo_t for saving to file/file object sox_encodinginfo_t get_encodinginfo_for_save( const std::string& format, + const caffe2::TypeMeta dtype, const c10::optional& compression, const c10::optional& encoding, const c10::optional& bits_per_sample); From e1e1998d7271045cb9ce74e6540137523db5bb0e Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 03:43:44 +0000 Subject: [PATCH 07/12] Update docstring --- torchaudio/backend/sox_io_backend.py | 110 ++++++++++++++------------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 
59201c5c4f..f9fc4b745e 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -164,49 +164,6 @@ def save( ): """Save audio data to file. - Supported formats/encodings/bit depths/compression are; - - ``"wav"``, ``"amb"`` - - 32-bit floating-point PCM - - 32-bit signed integer PCM - - 24-bit signed integer PCM - - 16-bit signed integer PCM - - 8-bit unsigned integer PCM - - 8-bit mu-law - - 8-bit a-law - - ``"mp3"`` - Fixed bit rate (such as 128kHz) and variable bit rate compression. - Default: VBR with high quality. - - ``"flac"`` - - 8-bit - - 16-bit - - 24-bit (default) - - ``"ogg"``, ``"vorbis"`` - - Different quality level. Default: approx. 112kbps - - ``"sph"`` - - 8-bit signed integer PCM - - 16-bit signed integer PCM - - 24-bit signed integer PCM - - 32-bit signed integer PCM - - 8-bit mu-law - - 8-bit a-law - - 16-bit a-law - - 24-bit a-law - - 32-bit a-law - - ``"amr-nb"`` - Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s - - Note: - To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, - ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has - to be linked to ``libsox`` and corresponding codec libraries such as ``libmad`` - or ``libmp3lame`` etc. - Args: filepath (str or pathlib.Path): Path to save file. This function also handles ``pathlib.Path`` objects, but is annotated @@ -252,14 +209,17 @@ def save( If not provided, the default value is picked based on ``format`` and ``bits_per_sample``. ``"wav"``, ``"amb"`` + - | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the + | Tensor is used to determine the default value. 
+ - ``"PCM_U"`` if dtype is ``uint8`` + - ``"PCM_S"`` if dtype is ``int16`` or ``int32`` + - ``"PCM_F"`` if dtype is ``float32`` + - ``"PCM_U"`` if ``bits_per_sample=8`` - ``"PCM_S"`` otherwise ``"sph"`` format; - - the default value is ``"PCM_S"`` - - Different formats support different set of encodings. Providing a value that is not - supported by the format will not cause an error, but will fallback to its default value. + - the default value is ``"PCM_S"`` bits_per_sample (int, optional): Changes the bit depth for the supported formats. When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the @@ -268,9 +228,15 @@ def save( Default Value; If not provided, the default values are picked based on ``format`` and ``"encoding"``; - ``"wav"`` format; + ``"wav"``, ``"amb"``; + - | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the + | Tensor is used. + - ``8`` if dtype is ``uint8`` + - ``16`` if dtype is ``int16`` + - ``32`` if dtype is ``int32`` or ``float32`` + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` - - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. + - ``16`` if ``encoding`` is ``"PCM_S"`` - ``32`` if ``encoding`` is ``"PCM_F"`` ``"flac"`` format; @@ -285,8 +251,50 @@ def save( - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. - ``32`` if ``encoding`` is ``"PCM_F"`` - Different formats support different set of encodings. Providing a value that is not - supported by the format will not cause an error, but will fallback to its default value. + Supported formats/encodings/bit depth/compression are; + + ``"wav"``, ``"amb"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
+ + ``"mp3"`` + Fixed bit rate (such as 128kbps) and variable bit rate compression. + Default: VBR with high quality. + + ``"flac"`` + - 8-bit + - 16-bit + - 24-bit (default) + + ``"ogg"``, ``"vorbis"`` + - Different quality level. Default: approx. 112kbps + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + + ``"amr-nb"`` + Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s + + Note: + To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, + ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has + to be linked to ``libsox`` and corresponding codec libraries such as ``libmad`` + or ``libmp3lame`` etc. """ if not torch.jit.is_scripting(): if hasattr(filepath, 'write'): From 4fe6918320c3bd9ef5944f5616e94e11438198c4 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 03:59:49 +0000 Subject: [PATCH 08/12] Fix style --- torchaudio/backend/sox_io_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index f9fc4b745e..16abc70deb 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -187,12 +187,13 @@ def save( and lowest quality. Default: ``3``. See the detail at http://sox.sourceforge.net/soxformat.html. - format (str, optional): - If provided, overwrite the audio format. This parameter is required in cases where - the ``filepath`` parameter is file-like object or ``filepath`` parameter represents - the path to a file on a local system but missing file extension or has different - extension. - When not provided, the value of file extension is used. + format (str, optional): Override the audio format.
+ When ``filepath`` argument is a path-like object, audio format is inferred from + file extension. If file extension is missing or different, you can specify the + correct format with this argument. + + When ``filepath`` argument is a file-like object, this argument is required. + Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, ``"amb"``, ``"flac"`` and ``"sph"``. encoding (str, optional): Changes the encoding for the supported formats. From 01c45cdb30db69e664e7587a77fd698873eb648d Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 04:03:41 +0000 Subject: [PATCH 09/12] fixup! Fix style --- torchaudio/csrc/sox/io.cpp | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index ae62411c5a..8bc520feba 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -14,31 +14,32 @@ namespace { std::string get_encoding(sox_encoding_t encoding) { switch (encoding) { + case SOX_ENCODING_UNKNOWN: + return "UNKNOWN"; case SOX_ENCODING_SIGN2: - return ENCODING_PCM_SIGNED; + return "PCM_S"; case SOX_ENCODING_UNSIGNED: - return ENCODING_PCM_UNSIGNED; + return "PCM_U"; case SOX_ENCODING_FLOAT: - return ENCODING_PCM_FLOAT; + return "PCM_F"; case SOX_ENCODING_FLAC: - return ENCODING_FLAC; + return "FLAC"; case SOX_ENCODING_ULAW: - return ENCODING_ULAW; + return "ULAW"; case SOX_ENCODING_ALAW: - return ENCODING_ALAW; + return "ALAW"; case SOX_ENCODING_MP3: - return ENCODING_MP3; + return "MP3"; case SOX_ENCODING_VORBIS: - return ENCODING_VORBIS; + return "VORBIS"; case SOX_ENCODING_AMR_WB: - return ENCODING_AMR_WB; + return "AMR_WB"; case SOX_ENCODING_AMR_NB: - return ENCODING_AMR_NB; + return "AMR_NB"; case SOX_ENCODING_OPUS: - return ENCODING_OPUS; - case SOX_ENCODING_UNKNOWN: + return "OPUS"; default: - return ENCODING_UNKNOWN; + return "UNKNOWN"; } } From 17de405be9f9ec154d25243e3b70ae0927bf471d
Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 04:04:51 +0000 Subject: [PATCH 10/12] fixup --- torchaudio/csrc/sox/utils.h | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index 012ba281ea..97832cea66 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -34,19 +34,6 @@ std::vector list_write_formats(); // Utilities for sox_io / sox_effects implementations //////////////////////////////////////////////////////////////////////////////// -const std::string ENCODING_UNKNOWN = "UNKNOWN"; -const std::string ENCODING_PCM_SIGNED = "PCM_S"; -const std::string ENCODING_PCM_UNSIGNED = "PCM_U"; -const std::string ENCODING_PCM_FLOAT = "PCM_F"; -const std::string ENCODING_FLAC = "FLAC"; -const std::string ENCODING_ULAW = "ULAW"; -const std::string ENCODING_ALAW = "ALAW"; -const std::string ENCODING_MP3 = "MP3"; -const std::string ENCODING_VORBIS = "VORBIS"; -const std::string ENCODING_AMR_WB = "AMR_WB"; -const std::string ENCODING_AMR_NB = "AMR_NB"; -const std::string ENCODING_OPUS = "OPUS"; - const std::unordered_set UNSUPPORTED_EFFECTS = {"input", "output", "spectrogram", "noiseprof", "noisered", "splice"}; From 357a0b13a08965d6226d15198723c9f44052de1a Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 13:18:05 +0000 Subject: [PATCH 11/12] addresses comments --- torchaudio/csrc/sox/types.cpp | 36 +++++++++++++++++++---------------- torchaudio/csrc/sox/types.h | 6 +++--- torchaudio/csrc/sox/utils.cpp | 6 +++--- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index 1e9c803e3b..e3e4f51c42 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -3,7 +3,7 @@ namespace torchaudio { namespace sox_utils { -Format from_string(const std::string& format) { +Format 
get_format_from_string(const std::string& format) { if (format == "wav") return Format::WAV; if (format == "mp3") @@ -56,7 +56,7 @@ std::string to_string(Encoding v) { } } -Encoding from_option(const c10::optional& encoding) { +Encoding get_encoding_from_option(const c10::optional& encoding) { if (!encoding.has_value()) return Encoding::NOT_PROVIDED; std::string v = encoding.value(); @@ -75,23 +75,27 @@ Encoding from_option(const c10::optional& encoding) { throw std::runtime_error(stream.str()); } -BitDepth from_option(const c10::optional& bit_depth) { +BitDepth get_bit_depth_from_option(const c10::optional& bit_depth) { if (!bit_depth.has_value()) return BitDepth::NOT_PROVIDED; int64_t v = bit_depth.value(); - if (v == 8) - return BitDepth::B8; - if (v == 16) - return BitDepth::B16; - if (v == 24) - return BitDepth::B24; - if (v == 32) - return BitDepth::B32; - if (v == 64) - return BitDepth::B64; - std::ostringstream stream; - stream << "Internal Error: unexpected bit depth value: " << v; - throw std::runtime_error(stream.str()); + switch(v) { + case 8: + return BitDepth::B8; + case 16: + return BitDepth::B16; + case 24: + return BitDepth::B24; + case 32: + return BitDepth::B32; + case 64: + return BitDepth::B64; + default: { + std::ostringstream s; + s << "Internal Error: unexpected bit depth value: " << v; + throw std::runtime_error(s.str()); + } + } } } // namespace sox_utils diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h index 6382e83b69..f3ed637478 100644 --- a/torchaudio/csrc/sox/types.h +++ b/torchaudio/csrc/sox/types.h @@ -17,7 +17,7 @@ enum class Format { SPHERE, }; -Format from_string(const std::string& format); +Format get_format_from_string(const std::string& format); enum class Encoding { NOT_PROVIDED, @@ -36,7 +36,7 @@ enum class Encoding { }; std::string to_string(Encoding v); -Encoding from_option(const c10::optional& encoding); +Encoding get_encoding_from_option(const c10::optional& encoding); enum class BitDepth : 
unsigned { NOT_PROVIDED = 0, @@ -47,7 +47,7 @@ enum class BitDepth : unsigned { B64 = 64, }; -BitDepth from_option(const c10::optional& bit_depth); +BitDepth get_bit_depth_from_option(const c10::optional& bit_depth); } // namespace sox_utils } // namespace torchaudio diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index a199a61ebb..f49bdde997 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -292,9 +292,9 @@ std::tuple get_save_encoding( const caffe2::TypeMeta dtype, const c10::optional& encoding, const c10::optional& bits_per_sample) { - const Format fmt = from_string(format); - const Encoding enc = from_option(encoding); - const BitDepth bps = from_option(bits_per_sample); + const Format fmt = get_format_from_string(format); + const Encoding enc = get_encoding_from_option(encoding); + const BitDepth bps = get_bit_depth_from_option(bits_per_sample); switch (fmt) { case Format::WAV: From 60a2df6746cc91d20176793841f963c1d2b5de36 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 12 Feb 2021 15:45:13 +0000 Subject: [PATCH 12/12] fix style --- torchaudio/csrc/sox/types.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index e3e4f51c42..51e8e720d6 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -79,7 +79,7 @@ BitDepth get_bit_depth_from_option(const c10::optional& bit_depth) { if (!bit_depth.has_value()) return BitDepth::NOT_PROVIDED; int64_t v = bit_depth.value(); - switch(v) { + switch (v) { case 8: return BitDepth::B8; case 16: