Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ if [ "${os}" == Linux ] ; then
# TODO: move this to docker
apt install -y -q libsndfile1
conda install -y -c conda-forge codecov pytest pytest-cov
pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy
pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy 'requests>=2.20'
else
# Note: installing librosa via pip fail because it will try to compile numba.
conda install -y -c conda-forge codecov pytest pytest-cov 'librosa>=0.8.0' parameterized scipy
pip install kaldi-io SoundFile
pip install kaldi-io SoundFile 'requests>=2.20'
fi
1 change: 1 addition & 0 deletions test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from .case_utils import (
TempDirMixin,
HttpServerMixin,
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
Expand Down
28 changes: 28 additions & 0 deletions test/torchaudio_unittest/common_utils/case_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import shutil
import os.path
import subprocess
import tempfile
import time
import unittest

import torch
Expand Down Expand Up @@ -40,6 +42,32 @@ def get_temp_path(self, *paths):
return path


class HttpServerMixin(TempDirMixin):
"""Mixin that serves temporary directory as web server

This class creates temporary directory and serve the directory as HTTP service.
The server is up through the execution of all the test suite defined under the subclass.
"""
_proc = None
_port = 8000

@classmethod
def setUpClass(cls):
super().setUpClass()
cls._proc = subprocess.Popen(
['python', '-m', 'http.server', f'{cls._port}'],
cwd=cls.get_base_temp_dir())
time.sleep(1.0)

@classmethod
def tearDownClass(cls):
super().tearDownClass()
cls._proc.kill()

def get_url(self, *route):
return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}'


class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
Expand Down
56 changes: 56 additions & 0 deletions test/torchaudio_unittest/soundfile_backend/load_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import tarfile
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -299,3 +300,58 @@ def test_wav(self, format_):
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')

data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T

with open(path, 'rb') as fileobj:
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)

def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav')

@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac')

def _test_tarfile(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')

data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype='float32')[0].T

with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = soundfile_backend.load(fileobj)

assert sr == sample_rate
self.assertEqual(expected, found)

def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile('wav')

@skipIfFormatNotSupported("FLAC")
def test_tarfile_flac(self):
"""Loading audio via file-like object works"""
self._test_tarfile('flac')
164 changes: 163 additions & 1 deletion test/torchaudio_unittest/sox_io_backend/load_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import io
import itertools
import tarfile

from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils

from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
get_asset_path,
get_wav_data,
load_wav,
Expand All @@ -19,6 +24,10 @@
)


if _mod_utils.is_module_available("requests"):
import requests


class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Expand Down Expand Up @@ -369,3 +378,156 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
assert sr == 16000


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Loading audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')

sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)

with open(path, 'rb') as fileobj:
found, sr = sox_io_backend.load(fileobj, format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio(self, ext, compression):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')

sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)

with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_tarfile(self, ext, compression):
"""Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')

sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(audio_path)

with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_io_backend.load(fileobj, format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)


@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_requests(self, ext, compression):
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)

sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2, compression=compression)
expected, _ = sox_io_backend.load(audio_path)

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=name_func)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
sample_rate = 8000
audio_file = 'test.wav'
audio_path = self.get_temp_path(audio_file)

original = get_wav_data('float32', num_channels=2)
save_wav(audio_path, original, sample_rate)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, frame_offset:frame_end]

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames)

assert sr == sample_rate
self.assertEqual(expected, found)
10 changes: 6 additions & 4 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ def load(
``[-1.0, 1.0]``.

Args:
filepath (str or pathlib.Path): Path to audio file.
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str``
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatiblity.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatiblity.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
Expand Down
25 changes: 20 additions & 5 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from typing import Tuple, Optional

import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)

import torchaudio
from .common import AudioMetaData


Expand Down Expand Up @@ -82,9 +84,17 @@ def load(
``[-1.0, 1.0]``.

Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.

Note:
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
Expand Down Expand Up @@ -112,8 +122,13 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
return torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format)
signal = torch.ops.torchaudio.sox_io_load_audio_file(
os.fspath(filepath), frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
Expand Down
6 changes: 6 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>


PYBIND11_MODULE(_torchaudio, m) {
py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t")
.def(py::init<>())
Expand Down Expand Up @@ -94,4 +96,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
}
Loading