Skip to content

Commit d8b8ccc

Browse files
committed
Simplify C++ registration with TORCH_LIBRARY
1 parent 748286a commit d8b8ccc

File tree

1 file changed

+62
-78
lines changed

1 file changed

+62
-78
lines changed

torchaudio/csrc/register.cpp

Lines changed: 62 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,86 +5,70 @@
55
#include <torchaudio/csrc/sox_io.h>
66
#include <torchaudio/csrc/sox_utils.h>
77

8-
namespace torchaudio {
9-
namespace {
8+
TORCH_LIBRARY(torchaudio, m) {
9+
//////////////////////////////////////////////////////////////////////////////
10+
// sox_utils.h
11+
//////////////////////////////////////////////////////////////////////////////
12+
m.class_<torchaudio::sox_utils::TensorSignal>("TensorSignal")
13+
.def(torch::init<torch::Tensor, int64_t, bool>())
14+
.def("get_tensor", &torchaudio::sox_utils::TensorSignal::getTensor)
15+
.def(
16+
"get_sample_rate",
17+
&torchaudio::sox_utils::TensorSignal::getSampleRate)
18+
.def(
19+
"get_channels_first",
20+
&torchaudio::sox_utils::TensorSignal::getChannelsFirst);
1021

11-
////////////////////////////////////////////////////////////////////////////////
12-
// sox_utils.h
13-
////////////////////////////////////////////////////////////////////////////////
14-
static auto registerTensorSignal =
15-
torch::class_<sox_utils::TensorSignal>("torchaudio", "TensorSignal")
16-
.def(torch::init<torch::Tensor, int64_t, bool>())
17-
.def("get_tensor", &sox_utils::TensorSignal::getTensor)
18-
.def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate)
19-
.def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst);
22+
m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
23+
m.def(
24+
"torchaudio::sox_utils_set_verbosity",
25+
&torchaudio::sox_utils::set_verbosity);
26+
m.def(
27+
"torchaudio::sox_utils_set_use_threads",
28+
&torchaudio::sox_utils::set_use_threads);
29+
m.def(
30+
"torchaudio::sox_utils_set_buffer_size",
31+
&torchaudio::sox_utils::set_buffer_size);
32+
m.def(
33+
"torchaudio::sox_utils_list_effects",
34+
&torchaudio::sox_utils::list_effects);
35+
m.def(
36+
"torchaudio::sox_utils_list_read_formats",
37+
&torchaudio::sox_utils::list_read_formats);
38+
m.def(
39+
"torchaudio::sox_utils_list_write_formats",
40+
&torchaudio::sox_utils::list_write_formats);
2041

21-
static auto registerSetSoxOptions =
22-
torch::RegisterOperators()
23-
.op("torchaudio::sox_utils_set_seed", &sox_utils::set_seed)
24-
.op("torchaudio::sox_utils_set_verbosity", &sox_utils::set_verbosity)
25-
.op("torchaudio::sox_utils_set_use_threads",
26-
&sox_utils::set_use_threads)
27-
.op("torchaudio::sox_utils_set_buffer_size",
28-
&sox_utils::set_buffer_size)
29-
.op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects)
30-
.op("torchaudio::sox_utils_list_read_formats",
31-
&sox_utils::list_read_formats)
32-
.op("torchaudio::sox_utils_list_write_formats",
33-
&sox_utils::list_write_formats);
42+
//////////////////////////////////////////////////////////////////////////////
43+
// sox_io.h
44+
//////////////////////////////////////////////////////////////////////////////
45+
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
46+
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
47+
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
48+
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);
3449

35-
////////////////////////////////////////////////////////////////////////////////
36-
// sox_io.h
37-
////////////////////////////////////////////////////////////////////////////////
38-
static auto registerSignalInfo =
39-
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
40-
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
41-
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
42-
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
50+
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
51+
m.def(
52+
"torchaudio::sox_io_load_audio_file",
53+
&torchaudio::sox_io::load_audio_file);
54+
m.def(
55+
"torchaudio::sox_io_save_audio_file",
56+
&torchaudio::sox_io::save_audio_file);
4357

44-
static auto registerGetInfo = torch::RegisterOperators().op(
45-
torch::RegisterOperators::options()
46-
.schema(
47-
"torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info")
48-
.catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>());
49-
50-
static auto registerLoadAudioFile = torch::RegisterOperators().op(
51-
torch::RegisterOperators::options()
52-
.schema(
53-
"torchaudio::sox_io_load_audio_file(str path, int frame_offset, int num_frames, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal signal")
54-
.catchAllKernel<
55-
decltype(sox_io::load_audio_file),
56-
&sox_io::load_audio_file>());
57-
58-
static auto registerSaveAudioFile = torch::RegisterOperators().op(
59-
torch::RegisterOperators::options()
60-
.schema(
61-
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()")
62-
.catchAllKernel<
63-
decltype(sox_io::save_audio_file),
64-
&sox_io::save_audio_file>());
65-
66-
////////////////////////////////////////////////////////////////////////////////
67-
// sox_effects.h
68-
////////////////////////////////////////////////////////////////////////////////
69-
static auto registerSoxEffects =
70-
torch::RegisterOperators()
71-
.op("torchaudio::sox_effects_initialize_sox_effects",
72-
&sox_effects::initialize_sox_effects)
73-
.op("torchaudio::sox_effects_shutdown_sox_effects",
74-
&sox_effects::shutdown_sox_effects)
75-
.op(torch::RegisterOperators::options()
76-
.schema(
77-
"torchaudio::sox_effects_apply_effects_tensor(__torch__.torch.classes.torchaudio.TensorSignal input_signal, str[][] effects) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal")
78-
.catchAllKernel<
79-
decltype(sox_effects::apply_effects_tensor),
80-
&sox_effects::apply_effects_tensor>())
81-
.op(torch::RegisterOperators::options()
82-
.schema(
83-
"torchaudio::sox_effects_apply_effects_file(str path, str[][] effects, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal")
84-
.catchAllKernel<
85-
decltype(sox_effects::apply_effects_file),
86-
&sox_effects::apply_effects_file>());
87-
88-
} // namespace
89-
} // namespace torchaudio
58+
//////////////////////////////////////////////////////////////////////////////
59+
// sox_effects.h
60+
//////////////////////////////////////////////////////////////////////////////
61+
m.def(
62+
"torchaudio::sox_effects_initialize_sox_effects",
63+
&torchaudio::sox_effects::initialize_sox_effects);
64+
m.def(
65+
"torchaudio::sox_effects_shutdown_sox_effects",
66+
&torchaudio::sox_effects::shutdown_sox_effects);
67+
m.def(
68+
"torchaudio::sox_effects_apply_effects_tensor",
69+
&torchaudio::sox_effects::apply_effects_tensor);
70+
m.def(
71+
"torchaudio::sox_effects_apply_effects_file",
72+
&torchaudio::sox_effects::apply_effects_file);
73+
}
9074
#endif

0 commit comments

Comments
 (0)