diff --git a/test/sox_io_backend/test_info.py b/test/sox_io_backend/test_info.py index 2b28ae4814..da5207a7e5 100644 --- a/test/sox_io_backend/test_info.py +++ b/test/sox_io_backend/test_info.py @@ -33,9 +33,9 @@ def test_wav(self, dtype, sample_rate, num_channels): data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) save_wav(path, data, sample_rate) info = sox_io_backend.info(path) - assert info.get_sample_rate() == sample_rate - assert info.get_num_frames() == sample_rate * duration - assert info.get_num_channels() == num_channels + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], @@ -49,9 +49,9 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) save_wav(path, data, sample_rate) info = sox_io_backend.info(path) - assert info.get_sample_rate() == sample_rate - assert info.get_num_frames() == sample_rate * duration - assert info.get_num_channels() == num_channels + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels @parameterized.expand(list(itertools.product( [8000, 16000], @@ -67,10 +67,10 @@ def test_mp3(self, sample_rate, num_channels, bit_rate): compression=bit_rate, duration=duration, ) info = sox_io_backend.info(path) - assert info.get_sample_rate() == sample_rate + assert info.sample_rate == sample_rate # mp3 does not preserve the number of samples - # assert info.get_num_frames() == sample_rate * duration - assert info.get_num_channels() == num_channels + # assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels @parameterized.expand(list(itertools.product( [8000, 16000], @@ -86,9 +86,9 @@ def test_flac(self, sample_rate, num_channels, compression_level): compression=compression_level, duration=duration, ) info = sox_io_backend.info(path) - assert info.get_sample_rate() == sample_rate - assert info.get_num_frames() == sample_rate * duration - assert info.get_num_channels() == num_channels + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels @parameterized.expand(list(itertools.product( [8000, 16000], @@ -104,9 +104,9 @@ def test_vorbis(self, sample_rate, num_channels, quality_level): compression=quality_level, duration=duration, ) info = sox_io_backend.info(path) - assert info.get_sample_rate() == sample_rate - assert info.get_num_frames() == sample_rate * duration - assert info.get_num_channels() == num_channels + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels @skipIfNoExtension @@ -120,6 +120,6 @@ def test_opus(self, bitrate, num_channels, compression_level): """`sox_io_backend.info` can check opus file correcty""" path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus') info = sox_io_backend.info(path) - assert info.get_sample_rate() == 48000 - assert info.get_num_frames() == 32768 - assert info.get_num_channels() == num_channels + assert info.sample_rate == 48000 + assert info.num_frames == 32768 + assert info.num_channels == num_channels diff --git a/test/sox_io_backend/test_torchscript.py b/test/sox_io_backend/test_torchscript.py index 522c1319ce..9a30aab0d2 100644 --- a/test/sox_io_backend/test_torchscript.py +++ b/test/sox_io_backend/test_torchscript.py @@ -20,7 +20,7 @@ ) -def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: +def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData: return torchaudio.info(filepath) @@ -63,9 +63,9 @@ def test_info_wav(self, dtype, sample_rate, num_channels): py_info = py_info_func(audio_path) ts_info = ts_info_func(audio_path) - assert py_info.get_sample_rate() == ts_info.get_sample_rate() - assert py_info.get_num_frames() == ts_info.get_num_frames() - assert py_info.get_num_channels() == ts_info.get_num_channels() + assert py_info.sample_rate == ts_info.sample_rate + assert py_info.num_frames == ts_info.num_frames + assert py_info.num_channels == ts_info.num_channels @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 2fef97917f..ced7b38d4b 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -6,10 +6,18 @@ ) +class AudioMetaData: + def __init__(self, sample_rate: int, num_frames: int, num_channels: int): + self.sample_rate = sample_rate + self.num_frames = num_frames + self.num_channels = num_channels + + @_mod_utils.requires_module('torchaudio._torchaudio') -def info(filepath: str) -> torch.classes.torchaudio.SignalInfo: +def info(filepath: str) -> AudioMetaData: """Get signal information of an audio file.""" - return torch.ops.torchaudio.sox_io_get_info(filepath) + sinfo = torch.ops.torchaudio.sox_io_get_info(filepath) + return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels()) @_mod_utils.requires_module('torchaudio._torchaudio') diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index 538c3f0ea0..1def910ac9 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -4,21 +4,10 @@ #include #include #include -#include namespace torchaudio { namespace { -//////////////////////////////////////////////////////////////////////////////// -// typedefs.h -//////////////////////////////////////////////////////////////////////////////// -static auto registerSignalInfo = - torch::class_("torchaudio", "SignalInfo") - .def(torch::init()) - .def("get_sample_rate", &SignalInfo::getSampleRate) - .def("get_num_channels", &SignalInfo::getNumChannels) - .def("get_num_frames", &SignalInfo::getNumFrames); - //////////////////////////////////////////////////////////////////////////////// // sox_utils.h //////////////////////////////////////////////////////////////////////////////// @@ -32,6 +21,12 @@ static auto registerTensorSignal = //////////////////////////////////////////////////////////////////////////////// // sox_io.h //////////////////////////////////////////////////////////////////////////////// +static auto registerSignalInfo = + torch::class_("torchaudio", "SignalInfo") + .def("get_sample_rate", &sox_io::SignalInfo::getSampleRate) + .def("get_num_channels", &sox_io::SignalInfo::getNumChannels) + .def("get_num_frames", &sox_io::SignalInfo::getNumFrames); + static auto registerGetInfo = torch::RegisterOperators().op( torch::RegisterOperators::options() .schema( diff --git a/torchaudio/csrc/sox_effects.h b/torchaudio/csrc/sox_effects.h index 6e4a26628f..14bdbbfabc 100644 --- a/torchaudio/csrc/sox_effects.h +++ b/torchaudio/csrc/sox_effects.h @@ -2,7 +2,6 @@ #define TORCHAUDIO_SOX_EFFECTS_H #include -#include namespace torchaudio { namespace sox_effects { diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index 1870dd8c9f..ec62a57ce6 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils; namespace torchaudio { namespace sox_io { -c10::intrusive_ptr get_info(const std::string& path) { +SignalInfo::SignalInfo( + const int64_t sample_rate_, + const int64_t num_channels_, + const int64_t num_frames_) + : sample_rate(sample_rate_), + num_channels(num_channels_), + num_frames(num_frames_){}; + +int64_t SignalInfo::getSampleRate() const { + return sample_rate; +} + +int64_t SignalInfo::getNumChannels() const { + return num_channels; +} + +int64_t SignalInfo::getNumFrames() const { + return num_frames; +} + +c10::intrusive_ptr get_info(const std::string& path) { SoxFormat sf(sox_open_read( path.c_str(), /*signal=*/nullptr, @@ -19,7 +39,7 @@ c10::intrusive_ptr get_info(const std::string& path) { throw std::runtime_error("Error opening audio file"); } - return c10::make_intrusive( + return c10::make_intrusive( static_cast(sf->signal.rate), static_cast(sf->signal.channels), static_cast(sf->signal.length / sf->signal.channels)); diff --git a/torchaudio/csrc/sox_io.h b/torchaudio/csrc/sox_io.h index 310687bb7d..6384f983be 100644 --- a/torchaudio/csrc/sox_io.h +++ b/torchaudio/csrc/sox_io.h @@ -3,12 +3,25 @@ #include #include -#include namespace torchaudio { namespace sox_io { -c10::intrusive_ptr get_info(const std::string& path); +struct SignalInfo : torch::CustomClassHolder { + int64_t sample_rate; + int64_t num_channels; + int64_t num_frames; + + SignalInfo( + const int64_t sample_rate_, + const int64_t num_channels_, + const int64_t num_frames_); + int64_t getSampleRate() const; + int64_t getNumChannels() const; + int64_t getNumFrames() const; +}; + +c10::intrusive_ptr get_info(const std::string& path); c10::intrusive_ptr load_audio_file( const std::string& path, diff --git a/torchaudio/csrc/typedefs.cpp b/torchaudio/csrc/typedefs.cpp deleted file mode 100644 index f4136cc918..0000000000 --- a/torchaudio/csrc/typedefs.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include - -namespace torchaudio { -SignalInfo::SignalInfo( - const int64_t sample_rate_, - const int64_t num_channels_, - const int64_t num_frames_) - : sample_rate(sample_rate_), - num_channels(num_channels_), - num_frames(num_frames_){}; - -int64_t SignalInfo::getSampleRate() const { - return sample_rate; -} - -int64_t SignalInfo::getNumChannels() const { - return num_channels; -} - -int64_t SignalInfo::getNumFrames() const { - return num_frames; -} -} // namespace torchaudio diff --git a/torchaudio/csrc/typedefs.h b/torchaudio/csrc/typedefs.h deleted file mode 100644 index ddd210e647..0000000000 --- a/torchaudio/csrc/typedefs.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef TORCHAUDIO_TYPDEFS_H -#define TORCHAUDIO_TYPDEFS_H - -#include - -namespace torchaudio { -struct SignalInfo : torch::CustomClassHolder { - int64_t sample_rate; - int64_t num_channels; - int64_t num_frames; - - SignalInfo( - const int64_t sample_rate_, - const int64_t num_channels_, - const int64_t num_frames_); - int64_t getSampleRate() const; - int64_t getNumChannels() const; - int64_t getNumFrames() const; -}; - -} // namespace torchaudio - -#endif diff --git a/torchaudio/extension/extension.py b/torchaudio/extension/extension.py index 4d8ac4dcba..b01ba13e39 100644 --- a/torchaudio/extension/extension.py +++ b/torchaudio/extension/extension.py @@ -12,38 +12,9 @@ def _init_extension(): _init_script_module(ext) else: warnings.warn('torchaudio C++ extension is not available.') - _init_dummy_module() def _init_script_module(module): path = importlib.util.find_spec(module).origin torch.classes.load_library(path) torch.ops.load_library(path) - - -def _init_dummy_module(): - class SignalInfo: - """Data class for audio format information - - Used when torchaudio C++ extension is not available for annotating - sox_io backend functions so that torchaudio is still importable - without extension. - This class has to implement the same interface as C++ equivalent. - """ - def __init__(self, sample_rate: int, num_channels: int, num_frames: int): - self.sample_rate = sample_rate - self.num_channels = num_channels - self.num_frames = num_frames - - def get_sample_rate(self): - return self.sample_rate - - def get_num_channels(self): - return self.num_channels - - def get_num_frames(self): - return self.num_frames - - DummyModule = namedtuple('torchaudio', ['SignalInfo']) - module = DummyModule(SignalInfo) - setattr(torch.classes, 'torchaudio', module)