Skip to content

Commit 70763f4

Browse files
committed
Add sox_utils module
1 parent 131e48b commit 70763f4

File tree

12 files changed

+231
-30
lines changed

12 files changed

+231
-30
lines changed

test/utils/__init__.py

Whitespace-only changes.

test/utils/test_sox_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from torchaudio.utils import sox_utils
2+
3+
from ..common_utils import (
4+
PytorchTestCase,
5+
skipIfNoExtension,
6+
)
7+
8+
9+
@skipIfNoExtension
10+
class TestSoxUtils(PytorchTestCase):
11+
"""Smoke tests for sox_util module"""
12+
def test_set_seed(self):
13+
"""`set_seed` does not crush"""
14+
sox_utils.set_seed(0)
15+
16+
def test_set_verbosity(self):
17+
"""`set_verbosity` does not crush"""
18+
for val in range(6, 0, -1):
19+
sox_utils.set_verbosity(val)
20+
21+
def test_set_buffer_size(self):
22+
"""`set_buffer_size` does not crush"""
23+
sox_utils.set_buffer_size(131072)
24+
# back to default
25+
sox_utils.set_buffer_size(8192)
26+
27+
def test_set_use_threads(self):
28+
"""`set_use_threads` does not crush"""
29+
sox_utils.set_use_threads(True)
30+
# back to default
31+
sox_utils.set_use_threads(False)
32+
33+
def test_list_effects(self):
34+
"""`list_effects` returns the list of available effects"""
35+
effects = sox_utils.list_effects()
36+
# We cannot infer what effects are available, so only check some of them.
37+
assert 'highpass' in effects
38+
assert 'phaser' in effects
39+
assert 'gain' in effects
40+
41+
def test_list_formats(self):
42+
"""`list_formats` returns the list of supported formats"""
43+
formats = sox_utils.list_formats()
44+
assert 'wav' in formats

torchaudio/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
compliance,
55
datasets,
66
kaldi_io,
7+
utils,
78
sox_effects,
89
transforms
910
)

torchaudio/csrc/register.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ static auto registerTensorSignal =
1818
.def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate)
1919
.def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst);
2020

21+
static auto registerSetSoxOptions =
22+
torch::RegisterOperators()
23+
.op("torchaudio::sox_utils_set_seed", &sox_utils::set_seed)
24+
.op("torchaudio::sox_utils_set_verbosity", &sox_utils::set_verbosity)
25+
.op("torchaudio::sox_utils_set_use_threads",
26+
&sox_utils::set_use_threads)
27+
.op("torchaudio::sox_utils_set_buffer_size",
28+
&sox_utils::set_buffer_size)
29+
.op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects)
30+
.op("torchaudio::sox_utils_list_formats", &sox_utils::list_formats);
31+
2132
////////////////////////////////////////////////////////////////////////////////
2233
// sox_io.h
2334
////////////////////////////////////////////////////////////////////////////////
@@ -53,12 +64,11 @@ static auto registerSaveAudioFile = torch::RegisterOperators().op(
5364
// sox_effects.h
5465
////////////////////////////////////////////////////////////////////////////////
5566
static auto registerSoxEffects =
56-
torch::RegisterOperators(
57-
"torchaudio::sox_effects_initialize_sox_effects",
58-
&sox_effects::initialize_sox_effects)
67+
torch::RegisterOperators()
68+
.op("torchaudio::sox_effects_initialize_sox_effects",
69+
&sox_effects::initialize_sox_effects)
5970
.op("torchaudio::sox_effects_shutdown_sox_effects",
60-
&sox_effects::shutdown_sox_effects)
61-
.op("torchaudio::sox_effects_list_effects", &sox_effects::list_effects);
71+
&sox_effects::shutdown_sox_effects);
6272

6373
} // namespace
6474
} // namespace torchaudio

torchaudio/csrc/sox_effects.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,5 @@ void shutdown_sox_effects() {
3939
}
4040
}
4141

42-
std::vector<std::string> list_effects() {
43-
std::vector<std::string> names;
44-
const sox_effect_fn_t* fns = sox_get_effect_fns();
45-
for (int i = 0; fns[i]; ++i) {
46-
const sox_effect_handler_t* handler = fns[i]();
47-
if (handler && handler->name)
48-
names.push_back(handler->name);
49-
}
50-
return names;
51-
}
52-
5342
} // namespace sox_effects
5443
} // namespace torchaudio

torchaudio/csrc/sox_effects.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ void initialize_sox_effects();
1010

1111
void shutdown_sox_effects();
1212

13-
std::vector<std::string> list_effects();
14-
1513
} // namespace sox_effects
1614
} // namespace torchaudio
1715

torchaudio/csrc/sox_io.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,12 @@ void save_audio_file(
125125
const c10::intrusive_ptr<TensorSignal>& signal,
126126
const double compression) {
127127
const auto tensor = signal->getTensor();
128-
const auto sample_rate = signal->getSampleRate();
129128
const auto channels_first = signal->getChannelsFirst();
130129

131130
validate_input_tensor(tensor);
132131

133132
const auto filetype = get_filetype(file_name);
134-
const auto signal_info =
135-
get_signalinfo(tensor, sample_rate, channels_first, filetype);
133+
const auto signal_info = get_signalinfo(signal.get(), filetype);
136134
const auto encoding_info =
137135
get_encodinginfo(filetype, tensor.dtype(), compression);
138136

torchaudio/csrc/sox_utils.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,49 @@
55
namespace torchaudio {
66
namespace sox_utils {
77

8+
void set_seed(const int64_t seed) {
9+
sox_get_globals()->ranqd1 = static_cast<sox_int32_t>(seed);
10+
}
11+
12+
void set_verbosity(const int64_t verbosity) {
13+
sox_get_globals()->verbosity = static_cast<unsigned>(verbosity);
14+
}
15+
16+
void set_use_threads(const bool use_threads) {
17+
sox_get_globals()->use_threads = static_cast<sox_bool>(use_threads);
18+
}
19+
20+
void set_buffer_size(const int64_t buffer_size) {
21+
sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size);
22+
}
23+
24+
std::vector<std::vector<std::string>> list_effects() {
25+
std::vector<std::vector<std::string>> effects;
26+
for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
27+
const sox_effect_handler_t* handler = (*fns)();
28+
if (handler && handler->name) {
29+
if (UNSUPPORTED_EFFECTS.find(handler->name) ==
30+
UNSUPPORTED_EFFECTS.end()) {
31+
effects.emplace_back(std::vector<std::string>{
32+
handler->name,
33+
handler->usage ? std::string(handler->usage) : std::string("")});
34+
}
35+
}
36+
}
37+
return effects;
38+
}
39+
40+
std::vector<std::string> list_formats() {
41+
std::vector<std::string> formats;
42+
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
43+
for (const char* const* names = fns->fn()->names; *names; ++names) {
44+
if (!strchr(*names, '/'))
45+
formats.emplace_back(*names);
46+
}
47+
}
48+
return formats;
49+
}
50+
851
TensorSignal::TensorSignal(
952
torch::Tensor tensor_,
1053
int64_t sample_rate_,
@@ -205,13 +248,13 @@ unsigned get_precision(
205248
}
206249

207250
sox_signalinfo_t get_signalinfo(
208-
const torch::Tensor& tensor,
209-
const int64_t sample_rate,
210-
const bool channels_first,
251+
const TensorSignal* signal,
211252
const std::string filetype) {
253+
auto tensor = signal->getTensor();
212254
return sox_signalinfo_t{
213-
/*rate=*/static_cast<sox_rate_t>(sample_rate),
214-
/*channels=*/static_cast<unsigned>(tensor.size(channels_first ? 0 : 1)),
255+
/*rate=*/static_cast<sox_rate_t>(signal->getSampleRate()),
256+
/*channels=*/
257+
static_cast<unsigned>(tensor.size(signal->getChannelsFirst() ? 0 : 1)),
215258
/*precision=*/get_precision(filetype, tensor.dtype()),
216259
/*length=*/static_cast<uint64_t>(tensor.numel())};
217260
}

torchaudio/csrc/sox_utils.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
namespace torchaudio {
88
namespace sox_utils {
99

10+
////////////////////////////////////////////////////////////////////////////////
11+
// APIs for Python interaction
12+
////////////////////////////////////////////////////////////////////////////////
13+
14+
/// Set sox global options
15+
void set_seed(const int64_t seed);
16+
17+
void set_verbosity(const int64_t verbosity);
18+
19+
void set_use_threads(const bool use_threads);
20+
21+
void set_buffer_size(const int64_t buffer_size);
22+
23+
std::vector<std::vector<std::string>> list_effects();
24+
25+
std::vector<std::string> list_formats();
26+
27+
/// Class for exchanging signal infomation (tensor + meta data) between
28+
/// C++ and Python for read/write operation.
1029
struct TensorSignal : torch::CustomClassHolder {
1130
torch::Tensor tensor;
1231
int64_t sample_rate;
@@ -22,6 +41,13 @@ struct TensorSignal : torch::CustomClassHolder {
2241
bool getChannelsFirst() const;
2342
};
2443

44+
////////////////////////////////////////////////////////////////////////////////
45+
// Utilities for sox_io / sox_effects implementations
46+
////////////////////////////////////////////////////////////////////////////////
47+
48+
const std::unordered_set<std::string> UNSUPPORTED_EFFECTS =
49+
{"input", "output", "spectrogram", "noiseprof", "noisered", "splice"};
50+
2551
/// helper class to automatically close sox_format_t*
2652
struct SoxFormat {
2753
explicit SoxFormat(sox_format_t* fd) noexcept;
@@ -84,9 +110,7 @@ const std::string get_filetype(const std::string path);
84110

85111
/// Get sox_signalinfo_t for passing a torch::Tensor object.
86112
sox_signalinfo_t get_signalinfo(
87-
const torch::Tensor& tensor,
88-
const int64_t sample_rate,
89-
const bool channels_first,
113+
const TensorSignal* signal,
90114
const std::string filetype);
91115

92116
/// Get sox_encofinginfo_t for saving audoi file

torchaudio/sox_effects/sox_effects.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
module_utils as _mod_utils,
88
misc_ops as _misc_ops,
99
)
10+
from torchaudio.utils.sox_utils import list_effects
1011

1112
if _mod_utils.is_module_available('torchaudio._torchaudio'):
1213
from torchaudio import _torchaudio
@@ -52,7 +53,7 @@ def effect_names() -> List[str]:
5253
Example
5354
>>> EFFECT_NAMES = torchaudio.sox_effects.effect_names()
5455
"""
55-
return torch.ops.torchaudio.sox_effects_list_effects()
56+
return list(list_effects().keys())
5657

5758

5859
@_mod_utils.requires_module('torchaudio._torchaudio')

0 commit comments

Comments
 (0)