diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index e7543c762a..72b7c9c88f 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -46,9 +46,9 @@ if [ "${os}" == Linux ] ; then # TODO: move this to docker apt install -y -q libsndfile1 conda install -y -c conda-forge codecov pytest pytest-cov - pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy + pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy 'requests>=2.20' else # Note: installing librosa via pip fail because it will try to compile numba. conda install -y -c conda-forge codecov pytest pytest-cov 'librosa>=0.8.0' parameterized scipy - pip install kaldi-io SoundFile + pip install kaldi-io SoundFile 'requests>=2.20' fi diff --git a/test/torchaudio_unittest/common_utils/__init__.py b/test/torchaudio_unittest/common_utils/__init__.py index 105a054864..cf3e717116 100644 --- a/test/torchaudio_unittest/common_utils/__init__.py +++ b/test/torchaudio_unittest/common_utils/__init__.py @@ -8,6 +8,7 @@ ) from .case_utils import ( TempDirMixin, + HttpServerMixin, TestBaseMixin, PytorchTestCase, TorchaudioTestCase, diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index 2e0a17b5da..9a9d491541 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -1,6 +1,8 @@ import shutil import os.path +import subprocess import tempfile +import time import unittest import torch @@ -40,6 +42,32 @@ def get_temp_path(self, *paths): return path +class HttpServerMixin(TempDirMixin): + """Mixin that serves a temporary directory as a web server + + This class creates a temporary directory and serves it over HTTP. + The server stays up for the duration of the whole test suite defined under the subclass.
+ """ + _proc = None + _port = 8000 + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._proc = subprocess.Popen( + ['python', '-m', 'http.server', f'{cls._port}'], + cwd=cls.get_base_temp_dir()) + time.sleep(1.0) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._proc.kill() + + def get_url(self, *route): + return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}' + + class TestBaseMixin: """Mixin to provide consistent way to define device/dtype/backend aware TestCase""" dtype = None diff --git a/test/torchaudio_unittest/soundfile_backend/load_test.py b/test/torchaudio_unittest/soundfile_backend/load_test.py index 4277ac03e9..399266de8f 100644 --- a/test/torchaudio_unittest/soundfile_backend/load_test.py +++ b/test/torchaudio_unittest/soundfile_backend/load_test.py @@ -1,4 +1,5 @@ import os +import tarfile from unittest.mock import patch import torch @@ -299,3 +300,58 @@ def test_wav(self, format_): @skipIfFormatNotSupported("FLAC") def test_flac(self, format_): self._test_format(format_) + + +@skipIfNoModule("soundfile") +class TestFileObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(path, data, sample_rate) + expected = soundfile.read(path, dtype='float32')[0].T + + with open(path, 'rb') as fileobj: + found, sr = soundfile_backend.load(fileobj) + assert sr == sample_rate + self.assertEqual(expected, found) + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj('flac') + + def _test_tarfile(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(audio_path, data, sample_rate) + expected = soundfile.read(audio_path, dtype='float32')[0].T + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = soundfile_backend.load(fileobj) + + assert sr == sample_rate + self.assertEqual(expected, found) + + def test_tarfile_wav(self): + """Loading audio via file-like object works""" + self._test_tarfile('wav') + + @skipIfFormatNotSupported("FLAC") + def test_tarfile_flac(self): + """Loading audio via file-like object works""" + self._test_tarfile('flac') diff --git a/test/torchaudio_unittest/sox_io_backend/load_test.py b/test/torchaudio_unittest/sox_io_backend/load_test.py index 933ab86195..3a4b0ba8fe 100644 --- a/test/torchaudio_unittest/sox_io_backend/load_test.py +++ b/test/torchaudio_unittest/sox_io_backend/load_test.py @@ -1,13 +1,18 @@ +import io import itertools +import tarfile -from torchaudio.backend import sox_io_backend from parameterized import parameterized +from torchaudio.backend import sox_io_backend +from torchaudio._internal import module_utils as _mod_utils from torchaudio_unittest.common_utils import ( TempDirMixin, + HttpServerMixin, PytorchTestCase, skipIfNoExec, skipIfNoExtension, + skipIfNoModule, get_asset_path, get_wav_data, load_wav, @@ -19,6 
+24,10 @@ ) + + +if _mod_utils.is_module_available("requests"): + import requests + + + class LoadTestBase(TempDirMixin, PytorchTestCase): def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): """`sox_io_backend.load` can load wav format correctly. @@ -369,3 +378,156 @@ def test_mp3(self): path = get_asset_path("mp3_without_ext") _, sr = sox_io_backend.load(path, format="mp3") assert sr == 16000 + + +@skipIfNoExtension +@skipIfNoExec('sox') +class TestFileObject(TempDirMixin, PytorchTestCase): + """ + In this test suite, the result of file-like object input is compared against file path input, + because the `load` function is rigorously tested for file path inputs to match libsox's result. + """ + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_fileobj(self, ext, compression): + """Loading audio via file object returns the same result as via file path.""" + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(path) + + with open(path, 'rb') as fileobj: + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytesio(self, ext, compression): + """Loading audio via BytesIO object returns the same result as via file path.""" + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(path) + + with open(path, 'rb') as file_: + fileobj = io.BytesIO(file_.read()) + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_tarfile(self, ext, compression): + """Loading audio from a tarfile member via file-like object returns the same result as via file path.""" + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + sox_utils.gen_audio_file( + audio_path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(audio_path) + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + +@skipIfNoExtension +@skipIfNoExec('sox') +@skipIfNoModule("requests") +class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_requests(self, ext, compression): + sample_rate = 16000 + format_ = ext 
if ext in ['mp3'] else None + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + + sox_utils.gen_audio_file( + audio_path, sample_rate, num_channels=2, compression=compression) + expected, _ = sox_io_backend.load(audio_path) + + url = self.get_url(audio_file) + with requests.get(url, stream=True) as resp: + found, sr = sox_io_backend.load(resp.raw, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand(list(itertools.product( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + )), name_func=name_func) + def test_frame(self, frame_offset, num_frames): + """num_frames and frame_offset correctly specify the region of data""" + sample_rate = 8000 + audio_file = 'test.wav' + audio_path = self.get_temp_path(audio_file) + + original = get_wav_data('float32', num_channels=2) + save_wav(audio_path, original, sample_rate) + frame_end = None if num_frames == -1 else frame_offset + num_frames + expected = original[:, frame_offset:frame_end] + + url = self.get_url(audio_file) + with requests.get(url, stream=True) as resp: + found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames) + + assert sr == sample_rate + self.assertEqual(expected, found) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index 4d403d1849..719224b827 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -82,10 +82,12 @@ def load( ``[-1.0, 1.0]``. Args: - filepath (str or pathlib.Path): Path to audio file. - This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str`` - for the consistency with "sox_io" backend, which has a restriction on type annotation - for TorchScript compiler compatiblity. + filepath (path-like object or file-like object): + Source of audio data. + Note: + * This argument is intentionally annotated as ``str`` only, + for consistency with the "sox_io" backend, which has a restriction + on type annotation due to TorchScript compiler compatibility. frame_offset (int): Number of frames to skip before start reading data. num_frames (int): diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 0c2a097080..1e6d417cb8 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -1,3 +1,4 @@ +import os from typing import Tuple, Optional import torch @@ -5,6 +6,7 @@ module_utils as _mod_utils, ) +import torchaudio from .common import AudioMetaData @@ -82,9 +84,17 @@ def load( ``[-1.0, 1.0]``. Args: - filepath (str or pathlib.Path): - Path to audio file. This function also handles ``pathlib.Path`` objects, but is - annotated as ``str`` for TorchScript compiler compatibility. + filepath (path-like object or file-like object): + Source of audio data. When the function is not compiled by TorchScript + (e.g. ``torch.jit.script``), the following types are accepted: + * ``path-like``: file path + * ``file-like``: Object with ``read(size: int) -> bytes`` method, + which returns a byte string of at most ``size`` length. + When the function is compiled by TorchScript, only ``str`` type is allowed. + + Note: + * This argument is intentionally annotated as ``str`` only due to + TorchScript compiler compatibility. frame_offset (int): Number of frames to skip before start reading data. num_frames (int): @@ -112,8 +122,13 @@ def load( integer type, else ``float32`` type. If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``. 
""" - # Cast to str in case type is `pathlib.Path` - filepath = str(filepath) + if not torch.jit.is_scripting(): + if hasattr(filepath, 'read'): + return torchaudio._torchaudio.load_audio_fileobj( + filepath, frame_offset, num_frames, normalize, channels_first, format) + signal = torch.ops.torchaudio.sox_io_load_audio_file( + os.fspath(filepath), frame_offset, num_frames, normalize, channels_first, format) + return signal.get_tensor(), signal.get_sample_rate() signal = torch.ops.torchaudio.sox_io_load_audio_file( filepath, frame_offset, num_frames, normalize, channels_first, format) return signal.get_tensor(), signal.get_sample_rate() diff --git a/torchaudio/csrc/pybind.cpp b/torchaudio/csrc/pybind.cpp index 9f20c40a89..caf9ad9b19 100644 --- a/torchaudio/csrc/pybind.cpp +++ b/torchaudio/csrc/pybind.cpp @@ -1,6 +1,8 @@ #include +#include #include + PYBIND11_MODULE(_torchaudio, m) { py::class_(m, "sox_signalinfo_t") .def(py::init<>()) @@ -94,4 +96,8 @@ PYBIND11_MODULE(_torchaudio, m) { "get_info", &torch::audio::get_info, "Gets information about an audio file"); + m.def( + "load_audio_fileobj", + &torchaudio::sox_io::load_audio_fileobj, + "Load audio from file object."); } diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp index f4cee53b93..6e36638a33 100644 --- a/torchaudio/csrc/sox/effects.cpp +++ b/torchaudio/csrc/sox/effects.cpp @@ -135,5 +135,88 @@ c10::intrusive_ptr apply_effects_file( tensor, chain.getOutputSampleRate(), channels_first_); } +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +std::tuple apply_effects_fileobj( + py::object fileobj, + std::vector> effects, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + + // Streaming decoding over file-like object is tricky because libsox operates on FILE pointer. + // The folloing is what `sox` and `play` commands do + // - file input -> FILE pointer + // - URL input -> call wget in suprocess and pipe the data -> FILE pointer + // - stdin -> FILE pointer + // + // We want to, instead, fetch byte strings chunk by chunk, consume them, and discard. + // + // Here is the approach + // 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of byte string + // This will perform header-based format detection, if necessary, then fill the metadata of + // sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns FILE* which points the + // buffer of the provided byte string. + // 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer in a way that it + // starts with unseen data, and append the new data read from the given fileobj. + // This will trick libsox as if it keeps reading from the FILE* continuously. + + // Prepare the buffer used throughout the lifecycle of SoxEffectChain. + // Using std::string and let it manage memory. + // 4096 is minimum size requried by auto_detect_format + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L48 + const size_t in_buffer_size = 4096; + std::string in_buffer(in_buffer_size, 'x'); + auto* in_buf = const_cast(in_buffer.data()); + + // Fetch the header, and copy it to the buffer. + auto header = static_cast(static_cast(fileobj.attr("read")(4096))); + memcpy(static_cast(in_buf), + static_cast(const_cast(header.data())), header.length()); + + // Open file (this starts reading the header) + SoxFormat sf(sox_open_mem_read( + in_buf, + in_buffer_size, + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? 
format.value().c_str() : nullptr)); + + // In case of streamed data, length can be 0 + validate_input_file(sf, /*check_length=*/false); + + // Prepare output buffer + std::vector out_buffer; + out_buffer.reserve(sf->signal.length); + + // Create and run SoxEffectsChain + const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/sf->encoding, + /*output_encoding=*/get_encodinginfo("wav", dtype, 0.)); + chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj); + for (const auto& effect : effects) { + chain.addEffect(effect); + } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + bool channels_first_ = channels_first.value_or(true); + auto tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + normalize.value_or(true), + channels_first_); + + return std::make_tuple( + tensor, + static_cast(chain.getOutputSampleRate())); +} + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace sox_effects } // namespace torchaudio diff --git a/torchaudio/csrc/sox/effects.h b/torchaudio/csrc/sox/effects.h index 23e0f40680..a2a598b408 100644 --- a/torchaudio/csrc/sox/effects.h +++ b/torchaudio/csrc/sox/effects.h @@ -1,6 +1,10 @@ #ifndef TORCHAUDIO_SOX_EFFECTS_H #define TORCHAUDIO_SOX_EFFECTS_H +#ifdef TORCH_API_INCLUDE_EXTENSION_H +#include +#endif // TORCH_API_INCLUDE_EXTENSION_H + #include #include @@ -22,6 +26,17 @@ c10::intrusive_ptr apply_effects_file( c10::optional& channels_first, c10::optional& format); +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +std::tuple apply_effects_fileobj( + py::object fileobj, + std::vector> effects, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format); + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace sox_effects } // namespace torchaudio diff --git a/torchaudio/csrc/sox/effects_chain.cpp b/torchaudio/csrc/sox/effects_chain.cpp index c7204bebe8..b397673439 100644 --- a/torchaudio/csrc/sox/effects_chain.cpp +++ b/torchaudio/csrc/sox/effects_chain.cpp @@ -198,7 +198,7 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) { priv->signal = signal; priv->index = 0; if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { - throw std::runtime_error("Failed to add effect: input_tensor"); + throw std::runtime_error("Internal Error: Failed to add effect: input_tensor"); } } @@ -207,7 +207,7 @@ void SoxEffectsChain::addOutputBuffer( SoxEffect e(sox_create_effect(get_tensor_output_handler())); static_cast(e->priv)->buffer = output_buffer; if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { - throw std::runtime_error("Failed to add effect: output_tensor"); + throw std::runtime_error("Internal Error: Failed to add effect: output_tensor"); } } @@ -219,7 +219,7 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) { sox_effect_options(e, 1, opts); if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { std::ostringstream stream; - stream << "Failed to add effect: input " << sf->filename; + stream << "Internal Error: Failed to add effect: input " << sf->filename; throw std::runtime_error(stream.str()); } } @@ -230,7 +230,7 @@ void SoxEffectsChain::addOutputFile(sox_format_t* sf) { static_cast(e->priv)->sf = sf; if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { std::ostringstream stream; - stream << "Failed to add effect: output " << 
sf->filename; + stream << "Internal Error: Failed to add effect: output " << sf->filename; throw std::runtime_error(stream.str()); } } @@ -266,7 +266,7 @@ void SoxEffectsChain::addEffect(const std::vector effect) { if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { std::ostringstream stream; - stream << "Failed to add effect: \"" << name; + stream << "Internal Error: Failed to add effect: \"" << name; for (size_t i = 1; i < num_args; ++i) { stream << " " << effect[i]; } @@ -283,5 +283,132 @@ int64_t SoxEffectsChain::getOutputSampleRate() { return interm_sig_.rate; } +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +namespace { + +/// Helper struct for passing a file-like object to SoxEffectsChain +struct FileObjInputPriv { + sox_format_t* sf; + py::object* fileobj; + char* buffer; + uint64_t buffer_size; +}; + +/// Callback function to feed byte string +/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278 +int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { + auto priv = static_cast(effp->priv); + auto sf = priv->sf; + auto fileobj = priv->fileobj; + auto buffer = priv->buffer; + auto buffer_size = priv->buffer_size; + + // 1. Refresh the buffer + // + // NOTE: + // Since the underlying FILE* was opened with `fmemopen`, the only way + // libsox detects EOF is by reaching the end of the buffer (a null byte won't help). + // Therefore we need to align the content to the end of the buffer; otherwise, + // libsox will keep reading beyond the intended length. + // + // Before: + // + // |<--------consumed------->|<-remaining->| + // |*************************|-------------| + // ^ ftell + // + // After: + // + // |<-offset->|<-remaining->|<--new data-->| + // |**********|-------------|++++++++++++++| + // ^ ftell + + const auto num_consumed = sf->tell_off; + const auto num_remain = buffer_size - num_consumed; + + // 1.1. First, we fetch the data to see if there is data to fill the buffer + py::bytes chunk_ = fileobj->attr("read")(num_consumed); + const auto num_refill = py::len(chunk_); + const auto offset = buffer_size - (num_remain + num_refill); + + if(num_refill > num_consumed) { + std::ostringstream message; + message << "Tried to read up to " << num_consumed << " bytes, but " + << "received " << num_refill << " bytes. " + << "The given object does not conform to the read protocol of file objects."; + throw std::runtime_error(message.str()); + } + + // 1.2. Move the unconsumed data towards the beginning of the buffer. + if (num_remain) { + auto src = static_cast(buffer + num_consumed); + auto dst = static_cast(buffer + offset); + memmove(dst, src, num_remain); + } + + // 1.3. Refill the remaining buffer. + if (num_refill) { + auto chunk = static_cast(chunk_); + auto src = static_cast(const_cast(chunk.c_str())); + auto dst = buffer + offset + num_remain; + memcpy(dst, src, num_refill); + } + + // 1.4. Set the file pointer to the new offset + sf->tell_off = offset; + fseek ((FILE*)sf->fp, offset, SEEK_SET); + + // 2. Perform decoding operation + // The following part is practically the same as the "input" effect + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48 + + // Ensure that it's a multiple of the number of channels + *osamp -= *osamp % effp->out_signal.channels; + + // Read up to *osamp samples into obuf; + // store the actual number read back to *osamp + *osamp = sox_read(sf, obuf, *osamp); + + return *osamp? 
SOX_SUCCESS : SOX_EOF; +} + +sox_effect_handler_t* get_fileobj_input_handler() { + static sox_effect_handler_t handler{/*name=*/"input_fileobj_object", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/NULL, + /*drain=*/fileobj_input_drain, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(FileObjInputPriv)}; + return &handler; +} + +} // namespace + +void SoxEffectsChain::addInputFileObj( + sox_format_t* sf, + char* buffer, + uint64_t buffer_size, + py::object* fileobj) { + in_sig_ = sf->signal; + interm_sig_ = in_sig_; + + SoxEffect e(sox_create_effect(get_fileobj_input_handler())); + auto priv = static_cast(e->priv); + priv->sf = sf; + priv->fileobj = fileobj; + priv->buffer = buffer; + priv->buffer_size = buffer_size; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error("Internal Error: Failed to add effect: input fileobj"); + } +} + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace sox_effects_chain } // namespace torchaudio diff --git a/torchaudio/csrc/sox/effects_chain.h b/torchaudio/csrc/sox/effects_chain.h index 1eda53ba0d..b096b3eb3d 100644 --- a/torchaudio/csrc/sox/effects_chain.h +++ b/torchaudio/csrc/sox/effects_chain.h @@ -4,6 +4,10 @@ #include #include +#ifdef TORCH_API_INCLUDE_EXTENSION_H +#include +#endif // TORCH_API_INCLUDE_EXTENSION_H + namespace torchaudio { namespace sox_effects_chain { @@ -33,6 +37,16 @@ class SoxEffectsChain { void addEffect(const std::vector effect); int64_t getOutputNumChannels(); int64_t getOutputSampleRate(); + +#ifdef TORCH_API_INCLUDE_EXTENSION_H + + void addInputFileObj( + sox_format_t* sf, + char* buffer, + uint64_t buffer_size, + py::object* fileobj); + +#endif // TORCH_API_INCLUDE_EXTENSION_H }; } // namespace sox_effects_chain diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index a9d17d30e5..e381be14f8 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -49,13 +49,11 @@ c10::intrusive_ptr get_info( static_cast(sf->signal.length / sf->signal.channels)); } -c10::intrusive_ptr load_audio_file( - const std::string& path, +namespace { + +std::vector> get_effects( c10::optional& frame_offset, - c10::optional& num_frames, - c10::optional& normalize, - c10::optional& channels_first, - c10::optional& format) { + c10::optional& num_frames) { const auto offset = frame_offset.value_or(0); if (offset < 0) { throw std::runtime_error( @@ -79,7 +77,19 @@ c10::intrusive_ptr load_audio_file( os_offset << offset << "s"; effects.emplace_back(std::vector{"trim", os_offset.str()}); } + return effects; +} + +} // namespace +c10::intrusive_ptr load_audio_file( + const std::string& path, + c10::optional& frame_offset, + c10::optional& num_frames, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + auto effects = get_effects(frame_offset, num_frames); return torchaudio::sox_effects::apply_effects_file( path, effects, normalize, channels_first, format); } @@ -123,5 +133,21 @@ void save_audio_file( chain.run(); } +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +std::tuple load_audio_fileobj( + py::object fileobj, + c10::optional& frame_offset, + c10::optional& num_frames, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + auto effects = get_effects(frame_offset, num_frames); + return torchaudio::sox_effects::apply_effects_fileobj( + fileobj, effects, normalize, channels_first, format); +} + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace 
sox_io } // namespace torchaudio diff --git a/torchaudio/csrc/sox/io.h b/torchaudio/csrc/sox/io.h index 43d873ee06..d6e5310077 100644 --- a/torchaudio/csrc/sox/io.h +++ b/torchaudio/csrc/sox/io.h @@ -1,6 +1,10 @@ #ifndef TORCHAUDIO_SOX_IO_H #define TORCHAUDIO_SOX_IO_H +#ifdef TORCH_API_INCLUDE_EXTENSION_H +#include +#endif // TORCH_API_INCLUDE_EXTENSION_H + #include #include @@ -38,6 +42,18 @@ void save_audio_file( const c10::intrusive_ptr& signal, const double compression = 0.); +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +std::tuple load_audio_fileobj( + py::object fileobj, + c10::optional& frame_offset, + c10::optional& num_frames, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format); + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace sox_io } // namespace torchaudio diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 35502336a6..44f00084e8 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -92,15 +92,15 @@ SoxFormat::operator sox_format_t*() const noexcept { return fd_; } -void validate_input_file(const SoxFormat& sf) { +void validate_input_file(const SoxFormat& sf, bool check_length) { if (static_cast(sf) == nullptr) { throw std::runtime_error("Error loading audio file: failed to open file."); } if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { throw std::runtime_error("Error loading audio file: unknown encoding."); } - if (sf->signal.length == 0) { - throw std::runtime_error("Error reading audio file: unkown length."); + if (check_length && sf->signal.length == 0) { + throw std::runtime_error("Error reading audio file: unknown length."); } } diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index c9037e9e2f..ee8d1baa66 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -67,7 +67,7 @@ struct SoxFormat { /// /// Verify that input file is found, has known encoding, and not empty -void validate_input_file(const SoxFormat& sf); +void validate_input_file(const SoxFormat& sf, bool check_length=true); /// /// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
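For reference, a minimal usage sketch of the file-like object input path exercised by the new tests above. It is illustrative only: 'test.wav' and the localhost URL are placeholder inputs, and it assumes the torchaudio sox extension is built and `requests` is installed.

import io
import requests  # only needed for the HTTP example
from torchaudio.backend import sox_io_backend

# 1. Plain binary file object (placeholder path).
with open('test.wav', 'rb') as fileobj:
    waveform, sample_rate = sox_io_backend.load(fileobj)

# 2. In-memory buffer.
with open('test.wav', 'rb') as f:
    data = io.BytesIO(f.read())
waveform, sample_rate = sox_io_backend.load(data)

# 3. HTTP response streamed via requests (placeholder URL);
#    resp.raw only needs to expose read(size) -> bytes.
with requests.get('http://localhost:8000/test.wav', stream=True) as resp:
    waveform, sample_rate = sox_io_backend.load(resp.raw)

As in the tests, mp3 input read from a file-like object should be loaded with an explicit format argument, e.g. sox_io_backend.load(fileobj, format='mp3').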