Skip to content

Commit 73db470

Browse files
committed
Support in-memory decoding in load/info/apply_effects_file
1 parent f2da586 commit 73db470

File tree

16 files changed

+583
-25
lines changed

16 files changed

+583
-25
lines changed

.circleci/unittest/linux/scripts/install.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ if [ "${os}" == Linux ] ; then
4646
# TODO: move this to docker
4747
apt install -y -q libsndfile1
4848
conda install -y -c conda-forge codecov pytest pytest-cov
49-
pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy
49+
pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy 'requests>=2.20'
5050
else
5151
# Note: installing librosa via pip fail because it will try to compile numba.
5252
conda install -y -c conda-forge codecov pytest pytest-cov 'librosa>=0.8.0' parameterized scipy
53-
pip install kaldi-io SoundFile
53+
pip install kaldi-io SoundFile 'requests>=2.20'
5454
fi

test/torchaudio_unittest/common_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from .case_utils import (
1010
TempDirMixin,
11+
HttpServerMixin,
1112
TestBaseMixin,
1213
PytorchTestCase,
1314
TorchaudioTestCase,

test/torchaudio_unittest/common_utils/case_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import shutil
22
import os.path
3+
import subprocess
34
import tempfile
5+
import time
46
import unittest
57

68
import torch
@@ -40,6 +42,32 @@ def get_temp_path(self, *paths):
4042
return path
4143

4244

45+
class HttpServerMixin(TempDirMixin):
46+
"""Mixin that serves temporary directory as web server
47+
48+
This class creates temporary directory and serve the directory as HTTP service.
49+
The server is up through the execution of all the test suite defined under the subclass.
50+
"""
51+
_proc = None
52+
_port = 8000
53+
54+
@classmethod
55+
def setUpClass(cls):
56+
super().setUpClass()
57+
cls._proc = subprocess.Popen(
58+
['python', '-m', 'http.server', f'{cls._port}'],
59+
cwd=cls.get_base_temp_dir())
60+
time.sleep(1.0)
61+
62+
@classmethod
63+
def tearDownClass(cls):
64+
super().tearDownClass()
65+
cls._proc.kill()
66+
67+
def get_url(self, *route):
68+
return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}'
69+
70+
4371
class TestBaseMixin:
4472
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
4573
dtype = None

test/torchaudio_unittest/soundfile_backend/load_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import tarfile
23
from unittest.mock import patch
34

45
import torch
@@ -299,3 +300,58 @@ def test_wav(self, format_):
299300
@skipIfFormatNotSupported("FLAC")
300301
def test_flac(self, format_):
301302
self._test_format(format_)
303+
304+
305+
@skipIfNoModule("soundfile")
306+
class TestFileObject(TempDirMixin, PytorchTestCase):
307+
def _test_fileobj(self, ext):
308+
"""Loading audio via file-like object works"""
309+
sample_rate = 16000
310+
path = self.get_temp_path(f'test.{ext}')
311+
312+
data = get_wav_data('float32', num_channels=2).numpy().T
313+
soundfile.write(path, data, sample_rate)
314+
expected = soundfile.read(path, dtype='float32')[0].T
315+
316+
with open(path, 'rb') as fileobj:
317+
found, sr = soundfile_backend.load(fileobj)
318+
assert sr == sample_rate
319+
self.assertEqual(expected, found)
320+
321+
def test_fileobj_wav(self):
322+
"""Loading audio via file-like object works"""
323+
self._test_fileobj('wav')
324+
325+
@skipIfFormatNotSupported("FLAC")
326+
def test_fileobj_flac(self):
327+
"""Loading audio via file-like object works"""
328+
self._test_fileobj('flac')
329+
330+
def _test_tarfile(self, ext):
331+
"""Loading audio via file-like object works"""
332+
sample_rate = 16000
333+
audio_file = f'test.{ext}'
334+
audio_path = self.get_temp_path(audio_file)
335+
archive_path = self.get_temp_path('archive.tar.gz')
336+
337+
data = get_wav_data('float32', num_channels=2).numpy().T
338+
soundfile.write(audio_path, data, sample_rate)
339+
expected = soundfile.read(audio_path, dtype='float32')[0].T
340+
341+
with tarfile.TarFile(archive_path, 'w') as tarobj:
342+
tarobj.add(audio_path, arcname=audio_file)
343+
with tarfile.TarFile(archive_path, 'r') as tarobj:
344+
fileobj = tarobj.extractfile(audio_file)
345+
found, sr = soundfile_backend.load(fileobj)
346+
347+
assert sr == sample_rate
348+
self.assertEqual(expected, found)
349+
350+
def test_tarfile_wav(self):
351+
"""Loading audio via file-like object works"""
352+
self._test_tarfile('wav')
353+
354+
@skipIfFormatNotSupported("FLAC")
355+
def test_tarfile_flac(self):
356+
"""Loading audio via file-like object works"""
357+
self._test_tarfile('flac')

test/torchaudio_unittest/sox_io_backend/load_test.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import io
12
import itertools
3+
import tarfile
24

3-
from torchaudio.backend import sox_io_backend
45
from parameterized import parameterized
6+
from torchaudio.backend import sox_io_backend
7+
from torchaudio._internal import module_utils as _mod_utils
58

69
from torchaudio_unittest.common_utils import (
710
TempDirMixin,
11+
HttpServerMixin,
812
PytorchTestCase,
913
skipIfNoExec,
1014
skipIfNoExtension,
15+
skipIfNoModule,
1116
get_asset_path,
1217
get_wav_data,
1318
load_wav,
@@ -19,6 +24,10 @@
1924
)
2025

2126

27+
if _mod_utils.is_module_available("requests"):
28+
import requests
29+
30+
2231
class LoadTestBase(TempDirMixin, PytorchTestCase):
2332
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
2433
"""`sox_io_backend.load` can load wav format correctly.
@@ -369,3 +378,156 @@ def test_mp3(self):
369378
path = get_asset_path("mp3_without_ext")
370379
_, sr = sox_io_backend.load(path, format="mp3")
371380
assert sr == 16000
381+
382+
383+
@skipIfNoExtension
384+
@skipIfNoExec('sox')
385+
class TestFileObject(TempDirMixin, PytorchTestCase):
386+
"""
387+
In this test suite, the result of file-like object input is compared against file path input,
388+
because `load` function is rigrously tested for file path inputs to match libsox's result,
389+
"""
390+
@parameterized.expand([
391+
('wav', None),
392+
('mp3', 128),
393+
('mp3', 320),
394+
('flac', 0),
395+
('flac', 5),
396+
('flac', 8),
397+
('vorbis', -1),
398+
('vorbis', 10),
399+
('amb', None),
400+
])
401+
def test_fileobj(self, ext, compression):
402+
"""Loading audio via file object returns the same result as via file path."""
403+
sample_rate = 16000
404+
format_ = ext if ext in ['mp3'] else None
405+
path = self.get_temp_path(f'test.{ext}')
406+
407+
sox_utils.gen_audio_file(
408+
path, sample_rate, num_channels=2,
409+
compression=compression)
410+
expected, _ = sox_io_backend.load(path)
411+
412+
with open(path, 'rb') as fileobj:
413+
found, sr = sox_io_backend.load(fileobj, format=format_)
414+
415+
assert sr == sample_rate
416+
self.assertEqual(expected, found)
417+
418+
@parameterized.expand([
419+
('wav', None),
420+
('mp3', 128),
421+
('mp3', 320),
422+
('flac', 0),
423+
('flac', 5),
424+
('flac', 8),
425+
('vorbis', -1),
426+
('vorbis', 10),
427+
('amb', None),
428+
])
429+
def test_bytesio(self, ext, compression):
430+
"""Loading audio via BytesIO object returns the same result as via file path."""
431+
sample_rate = 16000
432+
format_ = ext if ext in ['mp3'] else None
433+
path = self.get_temp_path(f'test.{ext}')
434+
435+
sox_utils.gen_audio_file(
436+
path, sample_rate, num_channels=2,
437+
compression=compression)
438+
expected, _ = sox_io_backend.load(path)
439+
440+
with open(path, 'rb') as file_:
441+
fileobj = io.BytesIO(file_.read())
442+
found, sr = sox_io_backend.load(fileobj, format=format_)
443+
444+
assert sr == sample_rate
445+
self.assertEqual(expected, found)
446+
447+
@parameterized.expand([
448+
('wav', None),
449+
('mp3', 128),
450+
('mp3', 320),
451+
('flac', 0),
452+
('flac', 5),
453+
('flac', 8),
454+
('vorbis', -1),
455+
('vorbis', 10),
456+
('amb', None),
457+
])
458+
def test_tarfile(self, ext, compression):
459+
"""Loading compressed audio via file-like object returns the same result as via file path."""
460+
sample_rate = 16000
461+
format_ = ext if ext in ['mp3'] else None
462+
audio_file = f'test.{ext}'
463+
audio_path = self.get_temp_path(audio_file)
464+
archive_path = self.get_temp_path('archive.tar.gz')
465+
466+
sox_utils.gen_audio_file(
467+
audio_path, sample_rate, num_channels=2,
468+
compression=compression)
469+
expected, _ = sox_io_backend.load(audio_path)
470+
471+
with tarfile.TarFile(archive_path, 'w') as tarobj:
472+
tarobj.add(audio_path, arcname=audio_file)
473+
with tarfile.TarFile(archive_path, 'r') as tarobj:
474+
fileobj = tarobj.extractfile(audio_file)
475+
found, sr = sox_io_backend.load(fileobj, format=format_)
476+
477+
assert sr == sample_rate
478+
self.assertEqual(expected, found)
479+
480+
481+
@skipIfNoExtension
482+
@skipIfNoExec('sox')
483+
@skipIfNoModule("requests")
484+
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
485+
@parameterized.expand([
486+
('wav', None),
487+
('mp3', 128),
488+
('mp3', 320),
489+
('flac', 0),
490+
('flac', 5),
491+
('flac', 8),
492+
('vorbis', -1),
493+
('vorbis', 10),
494+
('amb', None),
495+
])
496+
def test_requests(self, ext, compression):
497+
sample_rate = 16000
498+
format_ = ext if ext in ['mp3'] else None
499+
audio_file = f'test.{ext}'
500+
audio_path = self.get_temp_path(audio_file)
501+
502+
sox_utils.gen_audio_file(
503+
audio_path, sample_rate, num_channels=2, compression=compression)
504+
expected, _ = sox_io_backend.load(audio_path)
505+
506+
url = self.get_url(audio_file)
507+
with requests.get(url, stream=True) as resp:
508+
found, sr = sox_io_backend.load(resp.raw, format=format_)
509+
510+
assert sr == sample_rate
511+
self.assertEqual(expected, found)
512+
513+
@parameterized.expand(list(itertools.product(
514+
[0, 1, 10, 100, 1000],
515+
[-1, 1, 10, 100, 1000],
516+
)), name_func=name_func)
517+
def test_frame(self, frame_offset, num_frames):
518+
"""num_frames and frame_offset correctly specify the region of data"""
519+
sample_rate = 8000
520+
audio_file = 'test.wav'
521+
audio_path = self.get_temp_path(audio_file)
522+
523+
original = get_wav_data('float32', num_channels=2)
524+
save_wav(audio_path, original, sample_rate)
525+
frame_end = None if num_frames == -1 else frame_offset + num_frames
526+
expected = original[:, frame_offset:frame_end]
527+
528+
url = self.get_url(audio_file)
529+
with requests.get(url, stream=True) as resp:
530+
found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames)
531+
532+
assert sr == sample_rate
533+
self.assertEqual(expected, found)

torchaudio/backend/_soundfile_backend.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ def load(
8282
``[-1.0, 1.0]``.
8383
8484
Args:
85-
filepath (str or pathlib.Path): Path to audio file.
86-
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str``
87-
for the consistency with "sox_io" backend, which has a restriction on type annotation
88-
for TorchScript compiler compatiblity.
85+
filepath (path-like object or file-like object):
86+
Source of audio data.
87+
Note:
88+
* This argument is intentionally annotated as ``str`` only,
89+
for the consistency with "sox_io" backend, which has a restriction
90+
on type annotation due to TorchScript compiler compatiblity.
8991
frame_offset (int):
9092
Number of frames to skip before start reading data.
9193
num_frames (int):

torchaudio/backend/sox_io_backend.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import os
12
from typing import Tuple, Optional
23

34
import torch
45
from torchaudio._internal import (
56
module_utils as _mod_utils,
67
)
78

9+
import torchaudio
810
from .common import AudioMetaData
911

1012

@@ -82,9 +84,16 @@ def load(
8284
``[-1.0, 1.0]``.
8385
8486
Args:
85-
filepath (str or pathlib.Path):
86-
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
87-
annotated as ``str`` for TorchScript compiler compatibility.
87+
filepath (path-like object or file-like object):
88+
Source of audio data. When the function is not compiled by TorchScript,
89+
(e.g. ``torch.jit.script``), the following types are accepted;
90+
* ``path-like object``: file path
91+
* ``file-like object``: Any object with ``read`` method that returns ``bytes``.
92+
When the function is compiled by TorchScript, only ``str`` type is allowed.
93+
94+
Note:
95+
* This argument is intentionally annotated as ``str`` only due to
96+
TorchScript compiler compatibility.
8897
frame_offset (int):
8998
Number of frames to skip before start reading data.
9099
num_frames (int):
@@ -112,8 +121,13 @@ def load(
112121
integer type, else ``float32`` type. If ``channels_first=True``, it has
113122
``[channel, time]`` else ``[time, channel]``.
114123
"""
115-
# Cast to str in case type is `pathlib.Path`
116-
filepath = str(filepath)
124+
if not torch.jit.is_scripting():
125+
if hasattr(filepath, 'read'):
126+
return torchaudio._torchaudio.load_audio_fileobj(
127+
filepath, frame_offset, num_frames, normalize, channels_first, format)
128+
signal = torch.ops.torchaudio.sox_io_load_audio_file(
129+
os.fspath(filepath), frame_offset, num_frames, normalize, channels_first, format)
130+
return signal.get_tensor(), signal.get_sample_rate()
117131
signal = torch.ops.torchaudio.sox_io_load_audio_file(
118132
filepath, frame_offset, num_frames, normalize, channels_first, format)
119133
return signal.get_tensor(), signal.get_sample_rate()

torchaudio/csrc/pybind.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <torch/extension.h>
2+
#include <torchaudio/csrc/sox/io.h>
23
#include <torchaudio/csrc/sox/legacy.h>
34

5+
46
PYBIND11_MODULE(_torchaudio, m) {
57
py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t")
68
.def(py::init<>())
@@ -94,4 +96,8 @@ PYBIND11_MODULE(_torchaudio, m) {
9496
"get_info",
9597
&torch::audio::get_info,
9698
"Gets information about an audio file");
99+
m.def(
100+
"load_audio_fileobj",
101+
&torchaudio::sox_io::load_audio_fileobj,
102+
"Load audio from file object.");
97103
}

0 commit comments

Comments
 (0)