Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import itertools

import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized

Expand All @@ -24,7 +25,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`sox_io_backend.save` can save wav format."""
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
sox_io_backend.save(path, expected, sample_rate)
sox_io_backend.save(path, expected, sample_rate, dtype=None)
found, sr = load_wav(path)
assert sample_rate == sr
self.assertEqual(found, expected)
Expand Down Expand Up @@ -68,7 +69,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to mp3 with torchaudio
sox_io_backend.save(
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate, dtype=None)
# 2.2. Convert the mp3 to wav with Sox
sox_utils.convert_audio_file(mp3_path, wav_path)
# 2.3. Load
Expand Down Expand Up @@ -99,7 +100,7 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend.save(
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level, dtype=None)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
Expand Down Expand Up @@ -132,7 +133,7 @@ def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to vorbis with torchaudio
sox_io_backend.save(
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level, dtype=None)
# 2.2. Convert the vorbis to wav with Sox
sox_utils.convert_audio_file(vbs_path, wav_path)
# 2.3. Load
Expand Down Expand Up @@ -184,7 +185,7 @@ def assert_sphere(self, sample_rate, num_channels, duration):
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to sph with torchaudio
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate)
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate, dtype=None)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
Expand Down Expand Up @@ -216,7 +217,7 @@ def assert_amb(self, dtype, sample_rate, num_channels, duration):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amb with torchaudio
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
# 2.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path, wav_path)
# 2.3. Load
Expand Down Expand Up @@ -248,7 +249,7 @@ def assert_amr_nb(self, duration):
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path, wav_path)
# 2.3. Load
Expand Down Expand Up @@ -389,7 +390,7 @@ def test_channels_first(self, channels_first):
path = self.get_temp_path('data.wav')
data = get_wav_data('int32', 2, channels_first=channels_first)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
path, data, 8000, channels_first=channels_first, dtype=None)
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
Expand All @@ -402,7 +403,7 @@ def test_noncontiguous(self, dtype):
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2]
assert not expected.is_contiguous()
sox_io_backend.save(path, expected, 8000)
sox_io_backend.save(path, expected, 8000, dtype=None)
found = load_wav(path)[0]
self.assertEqual(found, expected)

Expand All @@ -415,10 +416,24 @@ def test_tensor_preserve(self, dtype):
expected = get_wav_data(dtype, 4)[::2, ::2]

data = expected.clone()
sox_io_backend.save(path, data, 8000)
sox_io_backend.save(path, data, 8000, dtype=None)

self.assertEqual(data, expected)

@parameterized.expand([
    ('float32', torch.tensor([-1.0, -0.5, 0, 0.5, 1.0], dtype=torch.float32)),
    ('int32', torch.tensor([-2147483648, -1073741824, 0, 1073741824, 2147483647], dtype=torch.int32)),
    ('int16', torch.tensor([-32768, -16384, 0, 16384, 32767], dtype=torch.int16)),
    ('uint8', torch.tensor([0, 64, 128, 192, 255], dtype=torch.uint8)),
])
def test_dtype_conversion(self, dtype, expected):
    """A float32 source saved with an explicit ``dtype`` is stored as that dtype.

    Each case pairs a target dtype name with the values expected when the
    file is read back without normalization.
    """
    wav_path = self.get_temp_path("data.wav")
    # Single-channel float32 source, laid out as a column of five samples.
    src = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0], dtype=torch.float32).unsqueeze(1)
    sox_io_backend.save(wav_path, src, 8000, dtype=dtype)
    # normalize=False keeps the stored sample values so they can be compared directly.
    saved, _ = load_wav(wav_path, normalize=False)
    self.assertEqual(saved, expected.unsqueeze(1))


@skipIfNoExtension
@skipIfNoExec('sox')
Expand Down Expand Up @@ -452,11 +467,11 @@ def test_fileobj(self, ext, compression):
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression)
sample_rate=sample_rate, compression=compression, dtype=None)
with open(res_path, 'wb') as fileobj:
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext)
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)

expected_data, _ = sox_io_backend.load(ref_path)
data, sr = sox_io_backend.load(res_path)
Expand Down Expand Up @@ -489,11 +504,11 @@ def test_bytesio(self, ext, compression):
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression)
sample_rate=sample_rate, compression=compression, dtype=None)
fileobj = io.BytesIO()
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext)
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
fileobj.seek(0)
with open(res_path, 'wb') as file_:
file_.write(fileobj.read())
Expand Down
21 changes: 17 additions & 4 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings
from typing import Tuple, Optional

import torch
Expand Down Expand Up @@ -178,15 +179,16 @@ def _save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format)
filepath, src, sample_rate, channels_first, compression, format, dtype)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format)
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand All @@ -197,6 +199,7 @@ def save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
"""Save audio data to file.

Expand Down Expand Up @@ -243,12 +246,22 @@ def save(
format (str, optional):
Output audio format. This is required when the output audio format cannot be inferred from
``filepath`` (such as from its file extension or from the ``name`` attribute of the given file object).
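dtype (str, optional):
Target dtype that ``src`` is converted to when it is saved.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
``dtype`` may only be specified when ``src`` is a ``float32`` Tensor;
passing it together with a Tensor of any other dtype raises an error.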
"""
if src.dtype == torch.float32 and dtype is None:
warnings.warn(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format)
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
return
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format)
filepath, src, sample_rate, channels_first, compression, format, dtype)


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
28 changes: 22 additions & 6 deletions torchaudio/csrc/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,19 @@ void save_audio_file(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format) {
c10::optional<std::string> format,
c10::optional<std::string> dtype) {
validate_input_tensor(tensor);

auto signal = TensorSignal(tensor, sample_rate, channels_first);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error(
"dtype conversion only supported for float32 tensors");
}
const auto tgt_dtype =
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();

const auto filetype = [&]() {
if (format.has_value())
Expand All @@ -124,8 +133,7 @@ void save_audio_file(
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);

SoxFormat sf(sox_open_write(
path.c_str(),
Expand Down Expand Up @@ -239,10 +247,19 @@ void save_audio_fileobj(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
std::string filetype) {
std::string filetype,
c10::optional<std::string> dtype) {
validate_input_tensor(tensor);

auto signal = TensorSignal(tensor, sample_rate, channels_first);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error(
"dtype conversion only supported for float32 tensors");
}
const auto tgt_dtype =
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();

if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
Expand All @@ -253,8 +270,7 @@ void save_audio_fileobj(
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);

AutoReleaseBuffer buffer;

Expand Down
6 changes: 4 additions & 2 deletions torchaudio/csrc/sox/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ void save_audio_file(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format);
c10::optional<std::string> format,
c10::optional<std::string> dtype);

#ifdef TORCH_API_INCLUDE_EXTENSION_H

Expand All @@ -68,7 +69,8 @@ void save_audio_fileobj(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
std::string filetype);
std::string filetype,
c10::optional<std::string> dtype);

#endif // TORCH_API_INCLUDE_EXTENSION_H

Expand Down
18 changes: 18 additions & 0 deletions torchaudio/csrc/sox/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,24 @@ caffe2::TypeMeta get_dtype(
return c10::scalarTypeToTypeMeta(dtype);
}

/// Translate a dtype name into the TypeMeta of the matching scalar type.
///
/// Recognized names are "uint8", "int16", "int32", "float32" and "float64".
///
/// @param dtype Name of the requested dtype.
/// @return TypeMeta corresponding to the scalar type named by ``dtype``.
/// @throws std::runtime_error if ``dtype`` is not one of the recognized names.
caffe2::TypeMeta get_dtype_from_str(const std::string dtype) {
  const auto tgt_dtype = [&]() {
    if (dtype == "uint8")
      return torch::kUInt8;
    if (dtype == "int16")
      return torch::kInt16;
    if (dtype == "int32")
      return torch::kInt32;
    if (dtype == "float32")
      return torch::kFloat32;
    if (dtype == "float64")
      return torch::kFloat64;
    // Report the rejected value together with the accepted ones so that the
    // caller can correct the argument without reading this source.
    throw std::runtime_error(
        "Unsupported dtype: \"" + dtype +
        "\". Valid values are \"uint8\", \"int16\", \"int32\", \"float32\" "
        "and \"float64\".");
  }();
  return c10::scalarTypeToTypeMeta(tgt_dtype);
}

torch::Tensor convert_to_tensor(
sox_sample_t* buffer,
const int32_t num_samples,
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/csrc/sox/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ caffe2::TypeMeta get_dtype(
const sox_encoding_t encoding,
const unsigned precision);

/// Convert a dtype name to the TypeMeta of the matching scalar type.
/// Names accepted by the definition in utils.cpp: "uint8", "int16", "int32",
/// "float32" and "float64". Any other name raises std::runtime_error.
caffe2::TypeMeta get_dtype_from_str(const std::string dtype);

///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
Expand Down