Description
With the support for file-like objects in the I/O functions, torchaudio can apply codecs as a form of augmentation.
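The underlying mechanism is an in-memory round trip: encode into a file-like buffer, then decode from the same buffer, without touching disk. As a dependency-free illustration only, the sketch below swaps in Python's standard `wave` module for torchaudio and uses 8-bit PCM as the "codec", so the quantization error stands in for the degradation a lossy codec such as mp3 would introduce. The helper name `pcm8_roundtrip` is hypothetical.

```python
import io
import math
import wave


def pcm8_roundtrip(samples, sample_rate):
    """Encode float samples in [-1, 1] as 8-bit mono WAV inside a BytesIO, then decode them."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:  # closing the writer leaves the BytesIO open
        writer.setnchannels(1)
        writer.setsampwidth(1)  # 8-bit WAV samples are unsigned bytes
        writer.setframerate(sample_rate)
        # map [-1, 1] onto [0, 255]; clamping guards against out-of-range input
        encoded = bytes(min(255, max(0, round((s + 1) * 127.5))) for s in samples)
        writer.writeframes(encoded)
    buffer.seek(0)  # rewind before reading the same buffer back
    with wave.open(buffer, "rb") as reader:
        decoded = reader.readframes(reader.getnframes())
    return [b / 127.5 - 1 for b in decoded]


sample_rate = 8000
clean = [0.5 * math.sin(2 * math.pi * 440 * n / sample_rate) for n in range(sample_rate)]
degraded = pcm8_roundtrip(clean, sample_rate)
# the round trip keeps the length but adds quantization noise
print(max(abs(a - b) for a, b in zip(clean, degraded)))
```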
There are two components required to add such a feature to torchaudio.
- Python frontend
  The user-facing function.
- Extend format support for the save function
  Currently, the audio formats that the "sox_io" save function supports are limited to certain formats (wave, mp3, flac, opus, amb, amr-nb, sphere, ogg/vorbis), while the underlying libsox can handle more formats. We need to extend the supported formats.
We welcome contributions from the open-source community. If you are interested in working on this, please read the following description and leave a comment saying which part you would like to work on, so that other people do not duplicate the work. If you are interested in working on extended format support for the save function, open one PR per format.
1. Interface
Implementation
The gist of the Python frontend is as follows: save the given waveform to an in-memory buffer with the codec and compression applied, then load it back.
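# in torchaudio/functional/functional.py
import io

import torchaudio

def apply_codec(waveform, sample_rate, format, channels_first=True, compression=None):
    # encode into an in-memory buffer instead of a file on disk
    buffer = io.BytesIO()
    torchaudio.save(buffer, waveform, sample_rate, channels_first=channels_first,
                    compression=compression, format=format)
    # rewind so that load reads the encoded data from the start;
    # pass the format as a hint, since it cannot always be detected from the bytes alone
    buffer.seek(0)
    augmented, _ = torchaudio.load(buffer, channels_first=channels_first, format=format)
    return augmented
Testing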
As a start, we need a smoke test that verifies the function does not crash for a variety of formats. Since we want this function to work on Windows, where libsox is not available, we need to test both the "sox_io" backend and the "soundfile" backend with the new interface. Therefore, adding a base class that implements the test logic, and extending it for the "sox_io" and "soundfile" backends, would do.
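import unittest
import torch
from parameterized import parameterized
from torchaudio.functional import apply_codec  # the function proposed above

class ApplyCodecTestBase:
    backend = None  # set by subclasses; switching to that backend is left to the test fixture

    def assert_codec(self, format, compression, sample_rate=8000, num_channels=2):
        waveform = torch.rand(num_channels, sample_rate) * 2 - 1  # 1 second of noise in [-1, 1)
        # run the function; the smoke test passes as long as it does not crash
        augmented = apply_codec(waveform, sample_rate, format, compression=compression)
        # check the channels only; the number of frames can change with formats like mp3
        assert augmented.shape[0] == num_channels

class ApplyCodecSoxIOTest(ApplyCodecTestBase, unittest.TestCase):
    backend = "sox_io"

    @parameterized.expand([(96,), (128,), (320,)])  # mp3 bit rates in kbps
    def test_mp3(self, compression):
        self.assert_codec("mp3", compression)

class ApplyCodecSoundfileTest(ApplyCodecTestBase, unittest.TestCase):
    backend = "soundfile-new"
    # similar to the above
For details on which parameters to parameterize and how, see the existing test: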
audio/test/torchaudio_unittest/sox_io_backend/info_test.py (lines 56 to 153 at commit f1d8d1e)
    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=name_func)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` can check mp3 file correctly"""
        duration = 1
        path = self.get_temp_path('data.mp3')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=bit_rate, duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        # mp3 does not preserve the number of samples
        # assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` can check flac file correctly"""
        duration = 1
        path = self.get_temp_path('data.flac')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression_level, duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=name_func)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` can check vorbis file correctly"""
        duration = 1
        path = self.get_temp_path('data.vorbis')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=quality_level, duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_sphere(self, sample_rate, num_channels):
        """`sox_io_backend.info` can check sph file correctly"""
        duration = 1
        path = self.get_temp_path('data.sph')
        sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration)
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_amb(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check amb file correctly"""
        duration = 1
        path = self.get_temp_path('data.amb')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels

    def test_amr_nb(self):
        """`sox_io_backend.info` can check amr-nb file correctly"""
        duration = 1
        num_channels = 1
        sample_rate = 8000
        path = self.get_temp_path('data.amr-nb')
        sox_utils.gen_audio_file(
            path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
        info = sox_io_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
2. Extend format support for save function
Implementation
The save function is implemented in the C++ backend. The call stack is as follows:
torchaudio.save
-> torchaudio.backend.sox_io_backend.save
-> torchaudio.backend.sox_io_backend._save
-> torchaudio/csrc/sox/io.cpp::save_audio_fileobj
Inside the save_audio_fileobj function, the target signal and encoding are determined by the get_signalinfo and get_encodinginfo functions. These functions call the get_encoding function to find the corresponding sox_encoding_t type, and the get_precision function to find the bit depth.
To extend format support, we need to add the correct mapping from the format string to sox_encoding_t in get_encoding, and the mapping from the format string to the bit depth in get_precision.
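To make the "two mappings per format" requirement concrete, here is a hypothetical Python model of the get_encoding and get_precision lookups; it is not the C++ code. The table names, the string stand-ins for sox_encoding_t values, the `is_save_supported` helper and the format name "my-format" are all illustrative assumptions, and the model deliberately simplifies by having both lookups reject any format they do not know. Only the flac entry reflects this issue (24 bit, per the testing note in this section).

```python
# Hypothetical Python model of the C++ get_encoding / get_precision lookups.
# Encodings are strings named after libsox's sox_encoding_t constants; the C++ code uses the enum.
FORMAT_TO_ENCODING = {"flac": "SOX_ENCODING_FLAC"}
# flac is saved as 24 bit, per the restriction noted in the testing part of this section
FORMAT_TO_PRECISION = {"flac": 24}


def get_encoding(format):
    """Map a format string to its encoding (model of get_encoding)."""
    if format not in FORMAT_TO_ENCODING:
        raise ValueError(f"Unsupported format: {format!r}")
    return FORMAT_TO_ENCODING[format]


def get_precision(format):
    """Map a format string to its bit depth (model of get_precision)."""
    if format not in FORMAT_TO_PRECISION:
        raise ValueError(f"Unsupported format: {format!r}")
    return FORMAT_TO_PRECISION[format]


def is_save_supported(format):
    # in this model a format only works end to end when both lookups know about it
    return format in FORMAT_TO_ENCODING and format in FORMAT_TO_PRECISION


# Registering a new format in only one table is not enough.
FORMAT_TO_ENCODING["my-format"] = "SOX_ENCODING_SIGN2"  # hypothetical format name
print(is_save_supported("my-format"))  # False until a precision is registered as well
```

The point of the model is the pairing: a PR that adds a format should touch both lookups, and the corresponding test should exercise both.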
Testing
Check out this for how the correctness of the save function is currently tested. The corresponding test method can be found here.
Add a similar test with the necessary parameterization. Note that certain formats have specific restrictions (for example, the SPHERE and flac formats are 24 bit).
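Following the itertools.product pattern of the existing info tests, a format restriction can be handled by narrowing one axis of the parameter grid. The sketch below is hypothetical: the helper `save_test_params`, the default bit depths (8, 16, 32) and the table `FIXED_BIT_DEPTH` are illustrative choices, with the 24-bit entries taken from the note above.

```python
import itertools

SAMPLE_RATES = [8000, 16000]
NUM_CHANNELS = [1, 2]
# formats restricted to a single bit depth, per the note above
FIXED_BIT_DEPTH = {"flac": 24, "sph": 24}


def save_test_params(format, bit_depths=(8, 16, 32)):
    """Build (sample_rate, num_channels, bit_depth) cases for a hypothetical save test."""
    if format in FIXED_BIT_DEPTH:
        # do not parameterize over bit depth for restricted formats
        bit_depths = (FIXED_BIT_DEPTH[format],)
    return list(itertools.product(SAMPLE_RATES, NUM_CHANNELS, bit_depths))


print(save_test_params("flac"))
# [(8000, 1, 24), (8000, 2, 24), (16000, 1, 24), (16000, 2, 24)]
```

The result can be passed directly to `@parameterized.expand(..., name_func=name_func)` in the same way the info tests do.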
Building and testing locally
To work on this, torchaudio needs to be built from source. Using a conda environment (Anaconda/Miniconda) is highly recommended.
The build also requires cmake and a nightly build of PyTorch. Refer to pytorch.org for installation instructions.
To install cmake, run pip install cmake.
Once the environment is set up, the following commands build torchaudio and run the corresponding tests:
BUILD_SOX=1 python setup.py develop
(cd test && pytest torchaudio_unittest/sox_io_backend/save_test.py -v)