Skip to content

Commit 674a71d

Browse files
author
Caroline Chen
authored
Add target dtype argument to save function for sox backend (#1204)
1 parent 47d97e3 commit 674a71d

File tree

6 files changed

+92
-26
lines changed

6 files changed

+92
-26
lines changed

test/torchaudio_unittest/sox_io_backend/save_test.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import io
22
import itertools
33

4+
import torch
45
from torchaudio.backend import sox_io_backend
56
from parameterized import parameterized
67

@@ -24,7 +25,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
2425
"""`sox_io_backend.save` can save wav format."""
2526
path = self.get_temp_path('data.wav')
2627
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
27-
sox_io_backend.save(path, expected, sample_rate)
28+
sox_io_backend.save(path, expected, sample_rate, dtype=None)
2829
found, sr = load_wav(path)
2930
assert sample_rate == sr
3031
self.assertEqual(found, expected)
@@ -68,7 +69,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
6869
save_wav(src_path, data, sample_rate)
6970
# 2.1. Convert the original wav to mp3 with torchaudio
7071
sox_io_backend.save(
71-
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
72+
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate, dtype=None)
7273
# 2.2. Convert the mp3 to wav with Sox
7374
sox_utils.convert_audio_file(mp3_path, wav_path)
7475
# 2.3. Load
@@ -99,7 +100,7 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):
99100
save_wav(src_path, data, sample_rate)
100101
# 2.1. Convert the original wav to flac with torchaudio
101102
sox_io_backend.save(
102-
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
103+
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level, dtype=None)
103104
# 2.2. Convert the flac to wav with Sox
104105
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
105106
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
@@ -132,7 +133,7 @@ def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
132133
save_wav(src_path, data, sample_rate)
133134
# 2.1. Convert the original wav to vorbis with torchaudio
134135
sox_io_backend.save(
135-
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
136+
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level, dtype=None)
136137
# 2.2. Convert the vorbis to wav with Sox
137138
sox_utils.convert_audio_file(vbs_path, wav_path)
138139
# 2.3. Load
@@ -184,7 +185,7 @@ def assert_sphere(self, sample_rate, num_channels, duration):
184185
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
185186
save_wav(src_path, data, sample_rate)
186187
# 2.1. Convert the original wav to sph with torchaudio
187-
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate)
188+
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate, dtype=None)
188189
# 2.2. Convert the sph to wav with Sox
189190
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
190191
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
@@ -216,7 +217,7 @@ def assert_amb(self, dtype, sample_rate, num_channels, duration):
216217
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
217218
save_wav(src_path, data, sample_rate)
218219
# 2.1. Convert the original wav to amb with torchaudio
219-
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
220+
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
220221
# 2.2. Convert the amb to wav with Sox
221222
sox_utils.convert_audio_file(amb_path, wav_path)
222223
# 2.3. Load
@@ -248,7 +249,7 @@ def assert_amr_nb(self, duration):
248249
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
249250
save_wav(src_path, data, sample_rate)
250251
# 2.1. Convert the original wav to amr_nb with torchaudio
251-
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
252+
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
252253
# 2.2. Convert the amr_nb to wav with Sox
253254
sox_utils.convert_audio_file(amr_path, wav_path)
254255
# 2.3. Load
@@ -389,7 +390,7 @@ def test_channels_first(self, channels_first):
389390
path = self.get_temp_path('data.wav')
390391
data = get_wav_data('int32', 2, channels_first=channels_first)
391392
sox_io_backend.save(
392-
path, data, 8000, channels_first=channels_first)
393+
path, data, 8000, channels_first=channels_first, dtype=None)
393394
found = load_wav(path)[0]
394395
expected = data if channels_first else data.transpose(1, 0)
395396
self.assertEqual(found, expected)
@@ -402,7 +403,7 @@ def test_noncontiguous(self, dtype):
402403
path = self.get_temp_path('data.wav')
403404
expected = get_wav_data(dtype, 4)[::2, ::2]
404405
assert not expected.is_contiguous()
405-
sox_io_backend.save(path, expected, 8000)
406+
sox_io_backend.save(path, expected, 8000, dtype=None)
406407
found = load_wav(path)[0]
407408
self.assertEqual(found, expected)
408409

@@ -415,10 +416,24 @@ def test_tensor_preserve(self, dtype):
415416
expected = get_wav_data(dtype, 4)[::2, ::2]
416417

417418
data = expected.clone()
418-
sox_io_backend.save(path, data, 8000)
419+
sox_io_backend.save(path, data, 8000, dtype=None)
419420

420421
self.assertEqual(data, expected)
421422

423+
@parameterized.expand([
424+
('float32', torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32)),
425+
('int32', torch.tensor([-2147483648, -1073741824, 0, 1073741824, 2147483647]).to(torch.int32)),
426+
('int16', torch.tensor([-32768, -16384, 0, 16384, 32767]).to(torch.int16)),
427+
('uint8', torch.tensor([0, 64, 128, 192, 255]).to(torch.uint8)),
428+
])
429+
def test_dtype_conversion(self, dtype, expected):
430+
"""`save` performs dtype conversion on float32 src tensors only."""
431+
path = self.get_temp_path("data.wav")
432+
data = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32).view(-1, 1)
433+
sox_io_backend.save(path, data, 8000, dtype=dtype)
434+
found = load_wav(path, normalize=False)[0]
435+
self.assertEqual(found, expected.view(-1, 1))
436+
422437

423438
@skipIfNoExtension
424439
@skipIfNoExec('sox')
@@ -452,11 +467,11 @@ def test_fileobj(self, ext, compression):
452467
res_path = self.get_temp_path(f'test.{ext}')
453468
sox_io_backend.save(
454469
ref_path, data, channels_first=channels_first,
455-
sample_rate=sample_rate, compression=compression)
470+
sample_rate=sample_rate, compression=compression, dtype=None)
456471
with open(res_path, 'wb') as fileobj:
457472
sox_io_backend.save(
458473
fileobj, data, channels_first=channels_first,
459-
sample_rate=sample_rate, compression=compression, format=ext)
474+
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
460475

461476
expected_data, _ = sox_io_backend.load(ref_path)
462477
data, sr = sox_io_backend.load(res_path)
@@ -489,11 +504,11 @@ def test_bytesio(self, ext, compression):
489504
res_path = self.get_temp_path(f'test.{ext}')
490505
sox_io_backend.save(
491506
ref_path, data, channels_first=channels_first,
492-
sample_rate=sample_rate, compression=compression)
507+
sample_rate=sample_rate, compression=compression, dtype=None)
493508
fileobj = io.BytesIO()
494509
sox_io_backend.save(
495510
fileobj, data, channels_first=channels_first,
496-
sample_rate=sample_rate, compression=compression, format=ext)
511+
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
497512
fileobj.seek(0)
498513
with open(res_path, 'wb') as file_:
499514
file_.write(fileobj.read())

torchaudio/backend/sox_io_backend.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23
from typing import Tuple, Optional
34

45
import torch
@@ -178,15 +179,16 @@ def _save(
178179
channels_first: bool = True,
179180
compression: Optional[float] = None,
180181
format: Optional[str] = None,
182+
dtype: Optional[str] = None,
181183
):
182184
if hasattr(filepath, 'write'):
183185
if format is None:
184186
raise RuntimeError('`format` is required when saving to file object.')
185187
torchaudio._torchaudio.save_audio_fileobj(
186-
filepath, src, sample_rate, channels_first, compression, format)
188+
filepath, src, sample_rate, channels_first, compression, format, dtype)
187189
else:
188190
torch.ops.torchaudio.sox_io_save_audio_file(
189-
os.fspath(filepath), src, sample_rate, channels_first, compression, format)
191+
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)
190192

191193

192194
@_mod_utils.requires_module('torchaudio._torchaudio')
@@ -197,6 +199,7 @@ def save(
197199
channels_first: bool = True,
198200
compression: Optional[float] = None,
199201
format: Optional[str] = None,
202+
dtype: Optional[str] = None,
200203
):
201204
"""Save audio data to file.
202205
@@ -243,12 +246,22 @@ def save(
243246
format (str, optional):
244247
Output audio format. This is required when the output audio format cannot be inferred from
245248
``filepath``, (such as file extension or ``name`` attribute of the given file object).
249+
dtype (str, optional):
250+
Output tensor dtype.
251+
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
252+
``dtype=None`` means no conversion is performed.
253+
``dtype`` may only be specified when ``src`` is a ``float32`` Tensor; passing it with a Tensor of any other dtype raises an error.
246254
"""
255+
if src.dtype == torch.float32 and dtype is None:
256+
warnings.warn(
257+
'`dtype` default value will be changed to `int16` in 0.9 release.'
258+
'Specify `dtype` to suppress this warning.'
259+
)
247260
if not torch.jit.is_scripting():
248-
_save(filepath, src, sample_rate, channels_first, compression, format)
261+
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
249262
return
250263
torch.ops.torchaudio.sox_io_save_audio_file(
251-
filepath, src, sample_rate, channels_first, compression, format)
264+
filepath, src, sample_rate, channels_first, compression, format, dtype)
252265

253266

254267
@_mod_utils.requires_module('torchaudio._torchaudio')

torchaudio/csrc/sox/io.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,19 @@ void save_audio_file(
107107
int64_t sample_rate,
108108
bool channels_first,
109109
c10::optional<double> compression,
110-
c10::optional<std::string> format) {
110+
c10::optional<std::string> format,
111+
c10::optional<std::string> dtype) {
111112
validate_input_tensor(tensor);
112113

113114
auto signal = TensorSignal(tensor, sample_rate, channels_first);
115+
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
116+
throw std::runtime_error(
117+
"dtype conversion only supported for float32 tensors");
118+
}
119+
const auto tgt_dtype =
120+
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
121+
? get_dtype_from_str(dtype.value())
122+
: tensor.dtype();
114123

115124
const auto filetype = [&]() {
116125
if (format.has_value())
@@ -124,8 +133,7 @@ void save_audio_file(
124133
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
125134
}
126135
const auto signal_info = get_signalinfo(&signal, filetype);
127-
const auto encoding_info =
128-
get_encodinginfo(filetype, tensor.dtype(), compression);
136+
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
129137

130138
SoxFormat sf(sox_open_write(
131139
path.c_str(),
@@ -239,10 +247,19 @@ void save_audio_fileobj(
239247
int64_t sample_rate,
240248
bool channels_first,
241249
c10::optional<double> compression,
242-
std::string filetype) {
250+
std::string filetype,
251+
c10::optional<std::string> dtype) {
243252
validate_input_tensor(tensor);
244253

245254
auto signal = TensorSignal(tensor, sample_rate, channels_first);
255+
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
256+
throw std::runtime_error(
257+
"dtype conversion only supported for float32 tensors");
258+
}
259+
const auto tgt_dtype =
260+
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
261+
? get_dtype_from_str(dtype.value())
262+
: tensor.dtype();
246263

247264
if (filetype == "amr-nb") {
248265
const auto num_channels = tensor.size(channels_first ? 0 : 1);
@@ -253,8 +270,7 @@ void save_audio_fileobj(
253270
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
254271
}
255272
const auto signal_info = get_signalinfo(&signal, filetype);
256-
const auto encoding_info =
257-
get_encodinginfo(filetype, tensor.dtype(), compression);
273+
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
258274

259275
AutoReleaseBuffer buffer;
260276

torchaudio/csrc/sox/io.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ void save_audio_file(
4646
int64_t sample_rate,
4747
bool channels_first,
4848
c10::optional<double> compression,
49-
c10::optional<std::string> format);
49+
c10::optional<std::string> format,
50+
c10::optional<std::string> dtype);
5051

5152
#ifdef TORCH_API_INCLUDE_EXTENSION_H
5253

@@ -68,7 +69,8 @@ void save_audio_fileobj(
6869
int64_t sample_rate,
6970
bool channels_first,
7071
c10::optional<double> compression,
71-
std::string filetype);
72+
std::string filetype,
73+
c10::optional<std::string> dtype);
7274

7375
#endif // TORCH_API_INCLUDE_EXTENSION_H
7476

torchaudio/csrc/sox/utils.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,24 @@ caffe2::TypeMeta get_dtype(
156156
return c10::scalarTypeToTypeMeta(dtype);
157157
}
158158

159+
caffe2::TypeMeta get_dtype_from_str(const std::string dtype) {
160+
const auto tgt_dtype = [&]() {
161+
if (dtype == "uint8")
162+
return torch::kUInt8;
163+
else if (dtype == "int16")
164+
return torch::kInt16;
165+
else if (dtype == "int32")
166+
return torch::kInt32;
167+
else if (dtype == "float32")
168+
return torch::kFloat32;
169+
else if (dtype == "float64")
170+
return torch::kFloat64;
171+
else
172+
throw std::runtime_error("Unsupported dtype");
173+
}();
174+
return c10::scalarTypeToTypeMeta(tgt_dtype);
175+
}
176+
159177
torch::Tensor convert_to_tensor(
160178
sox_sample_t* buffer,
161179
const int32_t num_samples,

torchaudio/csrc/sox/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ caffe2::TypeMeta get_dtype(
8585
const sox_encoding_t encoding,
8686
const unsigned precision);
8787

88+
caffe2::TypeMeta get_dtype_from_str(const std::string dtype);
89+
8890
///
8991
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
9092
/// NOTE: This function might modify the values in the input buffer to

0 commit comments

Comments
 (0)