Skip to content

Commit f1d8d1e

Browse files
authored
Support file-like object in save func (#1141)
* Support file-like object in save func * Disable CircleCI cache for TP artifacts for cleaner build
1 parent 72b7680 commit f1d8d1e

File tree

14 files changed

+390
-74
lines changed

14 files changed

+390
-74
lines changed

.circleci/config.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ jobs:
411411
paths:
412412
- conda
413413
- env
414-
- third_party/install
415414
- run:
416415
name: Install torchaudio
417416
command: .circleci/unittest/linux/scripts/install.sh
@@ -456,7 +455,6 @@ jobs:
456455
paths:
457456
- conda
458457
- env
459-
- third_party/install
460458
- run:
461459
name: Install torchaudio
462460
command: docker run -t --gpus all -e UPLOAD_CHANNEL -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh
@@ -569,7 +567,6 @@ jobs:
569567
paths:
570568
- conda
571569
- env
572-
- third_party/install
573570
- run:
574571
name: Install torchaudio
575572
command: .circleci/unittest/linux/scripts/install.sh
@@ -606,7 +603,6 @@ jobs:
606603
paths:
607604
- conda
608605
- env
609-
- third_party/install
610606
- run:
611607
name: Run style check
612608
command: .circleci/unittest/linux/scripts/run_style_checks.sh

.circleci/config.yml.in

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ jobs:
411411
paths:
412412
- conda
413413
- env
414-
- third_party/install
415414
- run:
416415
name: Install torchaudio
417416
command: .circleci/unittest/linux/scripts/install.sh
@@ -456,7 +455,6 @@ jobs:
456455
paths:
457456
- conda
458457
- env
459-
- third_party/install
460458
- run:
461459
name: Install torchaudio
462460
command: docker run -t --gpus all -e UPLOAD_CHANNEL -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh
@@ -569,7 +567,6 @@ jobs:
569567
paths:
570568
- conda
571569
- env
572-
- third_party/install
573570
- run:
574571
name: Install torchaudio
575572
command: .circleci/unittest/linux/scripts/install.sh
@@ -606,7 +603,6 @@ jobs:
606603
paths:
607604
- conda
608605
- env
609-
- third_party/install
610606
- run:
611607
name: Run style check
612608
command: .circleci/unittest/linux/scripts/run_style_checks.sh

test/torchaudio_unittest/soundfile_backend/save_test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import io
12
import itertools
23
from unittest.mock import patch
34

45
from torchaudio._internal import module_utils as _mod_utils
56
from torchaudio.backend import _soundfile_backend as soundfile_backend
6-
from parameterized import parameterized
77

88
from torchaudio_unittest.common_utils import (
99
TempDirMixin,
@@ -209,3 +209,43 @@ def test_channels_first(self, channels_first):
209209
found = load_wav(path)[0]
210210
expected = data if channels_first else data.transpose(1, 0)
211211
self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)
212+
213+
214+
@skipIfNoModule("soundfile")
215+
class TestFileObject(TempDirMixin, PytorchTestCase):
216+
def _test_fileobj(self, ext):
217+
"""Saving audio to file-like object works"""
218+
sample_rate = 16000
219+
path = self.get_temp_path(f'test.{ext}')
220+
221+
subtype = 'FLOAT' if ext == 'wav' else None
222+
data = get_wav_data('float32', num_channels=2)
223+
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
224+
expected = soundfile.read(path, dtype='float32')[0]
225+
226+
fileobj = io.BytesIO()
227+
soundfile_backend.save(fileobj, data, sample_rate, format=ext)
228+
fileobj.seek(0)
229+
found, sr = soundfile.read(fileobj, dtype='float32')
230+
231+
assert sr == sample_rate
232+
self.assertEqual(expected, found)
233+
234+
def test_fileobj_wav(self):
235+
"""Saving audio via file-like object works"""
236+
self._test_fileobj('wav')
237+
238+
@skipIfFormatNotSupported("FLAC")
239+
def test_fileobj_flac(self):
240+
"""Saving audio via file-like object works"""
241+
self._test_fileobj('flac')
242+
243+
@skipIfFormatNotSupported("NIST")
244+
def test_fileobj_nist(self):
245+
"""Saving audio via file-like object works"""
246+
self._test_fileobj('NIST')
247+
248+
@skipIfFormatNotSupported("OGG")
249+
def test_fileobj_ogg(self):
250+
"""Saving audio via file-like object works"""
251+
self._test_fileobj('OGG')

test/torchaudio_unittest/sox_io_backend/save_test.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import itertools
23

34
from torchaudio.backend import sox_io_backend
@@ -417,3 +418,88 @@ def test_tensor_preserve(self, dtype):
417418
sox_io_backend.save(path, data, 8000)
418419

419420
self.assertEqual(data, expected)
421+
422+
423+
@skipIfNoExtension
424+
@skipIfNoExec('sox')
425+
class TestFileObject(SaveTestBase):
426+
"""
427+
We compare the result of file-like object input against file path input because
428+
`save` function is rigorously tested for file path inputs to match libsox's result.
429+
"""
430+
@parameterized.expand([
431+
('wav', None),
432+
('mp3', 128),
433+
('mp3', 320),
434+
('flac', 0),
435+
('flac', 5),
436+
('flac', 8),
437+
('vorbis', -1),
438+
('vorbis', 10),
439+
('amb', None),
440+
])
441+
def test_fileobj(self, ext, compression):
442+
"""Saving audio to file object returns the same result as via file path."""
443+
sample_rate = 16000
444+
dtype = 'float32'
445+
num_channels = 2
446+
num_frames = 16000
447+
channels_first = True
448+
449+
data = get_wav_data(dtype, num_channels, num_frames=num_frames)
450+
451+
ref_path = self.get_temp_path(f'reference.{ext}')
452+
res_path = self.get_temp_path(f'test.{ext}')
453+
sox_io_backend.save(
454+
ref_path, data, channels_first=channels_first,
455+
sample_rate=sample_rate, compression=compression)
456+
with open(res_path, 'wb') as fileobj:
457+
sox_io_backend.save(
458+
fileobj, data, channels_first=channels_first,
459+
sample_rate=sample_rate, compression=compression, format=ext)
460+
461+
expected_data, _ = sox_io_backend.load(ref_path)
462+
data, sr = sox_io_backend.load(res_path)
463+
464+
assert sample_rate == sr
465+
self.assertEqual(expected_data, data)
466+
467+
@parameterized.expand([
468+
('wav', None),
469+
('mp3', 128),
470+
('mp3', 320),
471+
('flac', 0),
472+
('flac', 5),
473+
('flac', 8),
474+
('vorbis', -1),
475+
('vorbis', 10),
476+
('amb', None),
477+
])
478+
def test_bytesio(self, ext, compression):
479+
"""Saving audio to BytesIO object returns the same result as via file path."""
480+
sample_rate = 16000
481+
dtype = 'float32'
482+
num_channels = 2
483+
num_frames = 16000
484+
channels_first = True
485+
486+
data = get_wav_data(dtype, num_channels, num_frames=num_frames)
487+
488+
ref_path = self.get_temp_path(f'reference.{ext}')
489+
res_path = self.get_temp_path(f'test.{ext}')
490+
sox_io_backend.save(
491+
ref_path, data, channels_first=channels_first,
492+
sample_rate=sample_rate, compression=compression)
493+
fileobj = io.BytesIO()
494+
sox_io_backend.save(
495+
fileobj, data, channels_first=channels_first,
496+
sample_rate=sample_rate, compression=compression, format=ext)
497+
fileobj.seek(0)
498+
with open(res_path, 'wb') as file_:
499+
file_.write(fileobj.read())
500+
501+
expected_data, _ = sox_io_backend.load(ref_path)
502+
data, sr = sox_io_backend.load(res_path)
503+
504+
assert sample_rate == sr
505+
self.assertEqual(expected_data, data)

torchaudio/backend/_soundfile_backend.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def save(
138138
sample_rate: int,
139139
channels_first: bool = True,
140140
compression: Optional[float] = None,
141+
format: Optional[str] = None,
141142
):
142143
"""Save audio data to file.
143144
@@ -168,6 +169,9 @@ def save(
168169
otherwise ``[time, channel]``.
169170
compression (Optional[float]):
170171
Not used. It is here only for interface compatibility reasons with the "sox_io" backend.
172+
format (str, optional):
173+
Output audio format. This is required when the output audio format cannot be inferred from
174+
``filepath``, (such as file extension or ``name`` attribute of the given file object).
171175
"""
172176
if src.ndim != 2:
173177
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
@@ -176,8 +180,13 @@ def save(
176180
'`save` function of "soundfile" backend does not support "compression" parameter. '
177181
"The argument is silently ignored."
178182
)
183+
if hasattr(filepath, 'write'):
184+
if format is None:
185+
raise RuntimeError('`format` is required when saving to file object.')
186+
ext = format
187+
else:
188+
ext = str(filepath).split(".")[-1].lower()
179189

180-
ext = str(filepath).split(".")[-1].lower()
181190
if ext != "wav":
182191
subtype = None
183192
elif src.dtype == torch.uint8:
@@ -193,17 +202,16 @@ def save(
193202
else:
194203
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")
195204

196-
format_ = None
197205
# sph is an extension used in TED-LIUM, but soundfile does not recognize it as NIST format,
198206
# so we extend the extensions manually here
199-
if ext in ["nis", "nist", "sph"]:
200-
format_ = "NIST"
207+
if ext in ["nis", "nist", "sph"] and format is None:
208+
format = "NIST"
201209

202210
if channels_first:
203211
src = src.t()
204212

205213
soundfile.write(
206-
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format_
214+
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
207215
)
208216

209217

torchaudio/backend/sox_io_backend.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,33 @@ def load(
134134
return signal.get_tensor(), signal.get_sample_rate()
135135

136136

137+
@torch.jit.unused
138+
def _save(
139+
filepath: str,
140+
src: torch.Tensor,
141+
sample_rate: int,
142+
channels_first: bool = True,
143+
compression: Optional[float] = None,
144+
format: Optional[str] = None,
145+
):
146+
if hasattr(filepath, 'write'):
147+
if format is None:
148+
raise RuntimeError('`format` is required when saving to file object.')
149+
torchaudio._torchaudio.save_audio_fileobj(
150+
filepath, src, sample_rate, channels_first, compression, format)
151+
else:
152+
torch.ops.torchaudio.sox_io_save_audio_file(
153+
os.fspath(filepath), src, sample_rate, channels_first, compression, format)
154+
155+
137156
@_mod_utils.requires_module('torchaudio._torchaudio')
138157
def save(
139158
filepath: str,
140159
src: torch.Tensor,
141160
sample_rate: int,
142161
channels_first: bool = True,
143162
compression: Optional[float] = None,
163+
format: Optional[str] = None,
144164
):
145165
"""Save audio data to file.
146166
@@ -184,23 +204,15 @@ def save(
184204
| and lowest quality. Default: ``3``.
185205
186206
See the detail at http://sox.sourceforge.net/soxformat.html.
207+
format (str, optional):
208+
Output audio format. This is required when the output audio format cannot be inferred from
209+
``filepath``, (such as file extension or ``name`` attribute of the given file object).
187210
"""
188-
# Cast to str in case type is `pathlib.Path`
189-
filepath = str(filepath)
190-
if compression is None:
191-
ext = str(filepath).split('.')[-1].lower()
192-
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
193-
compression = 0.
194-
elif ext == 'mp3':
195-
compression = -4.5
196-
elif ext == 'flac':
197-
compression = 8.
198-
elif ext in ['ogg', 'vorbis']:
199-
compression = 3.
200-
else:
201-
raise RuntimeError(f'Unsupported file type: "{ext}"')
202-
signal = torch.classes.torchaudio.TensorSignal(src, sample_rate, channels_first)
203-
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)
211+
if not torch.jit.is_scripting():
212+
_save(filepath, src, sample_rate, channels_first, compression, format)
213+
return
214+
torch.ops.torchaudio.sox_io_save_audio_file(
215+
filepath, src, sample_rate, channels_first, compression, format)
204216

205217

206218
@_mod_utils.requires_module('torchaudio._torchaudio')

torchaudio/csrc/pybind.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,8 @@ PYBIND11_MODULE(_torchaudio, m) {
100100
"load_audio_fileobj",
101101
&torchaudio::sox_io::load_audio_fileobj,
102102
"Load audio from file object.");
103+
m.def(
104+
"save_audio_fileobj",
105+
&torchaudio::sox_io::save_audio_fileobj,
106+
"Save audio to file obj.");
103107
}

torchaudio/csrc/sox/effects.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
5959
// Create SoxEffectsChain
6060
const auto dtype = in_tensor.dtype();
6161
torchaudio::sox_effects_chain::SoxEffectsChain chain(
62-
/*input_encoding=*/get_encodinginfo("wav", dtype, 0.),
63-
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
62+
/*input_encoding=*/get_encodinginfo("wav", dtype),
63+
/*output_encoding=*/get_encodinginfo("wav", dtype));
6464

6565
// Prepare output buffer
6666
std::vector<sox_sample_t> out_buffer;
@@ -112,7 +112,7 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
112112
// Create and run SoxEffectsChain
113113
torchaudio::sox_effects_chain::SoxEffectsChain chain(
114114
/*input_encoding=*/sf->encoding,
115-
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
115+
/*output_encoding=*/get_encodinginfo("wav", dtype));
116116

117117
chain.addInputFile(sf);
118118
for (const auto& effect : effects) {
@@ -193,7 +193,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
193193
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
194194
torchaudio::sox_effects_chain::SoxEffectsChain chain(
195195
/*input_encoding=*/sf->encoding,
196-
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
196+
/*output_encoding=*/get_encodinginfo("wav", dtype));
197197
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
198198
for (const auto& effect : effects) {
199199
chain.addEffect(effect);

0 commit comments

Comments
 (0)