Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def test_wav(self, dtype, sample_rate, num_channels):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
Comment on lines -36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The functions are still in there so the change is not BC breaking, right? Is there a reason why you are updating the tests then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The functions are still in there so the change is not BC breaking, right?

No. These get_* functions are removed so this is BC breaking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user tries to use them, what will happen? I don't see any instructions telling the user what to do instead.


@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand All @@ -49,9 +49,9 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -67,10 +67,10 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
compression=bit_rate, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.sample_rate == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
# assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -86,9 +86,9 @@ def test_flac(self, sample_rate, num_channels, compression_level):
compression=compression_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -104,9 +104,9 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
compression=quality_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels


@skipIfNoExtension
Expand All @@ -120,6 +120,6 @@ def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.info` can check opus file correcty"""
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
info = sox_io_backend.info(path)
assert info.get_sample_rate() == 48000
assert info.get_num_frames() == 32768
assert info.get_num_channels() == num_channels
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
8 changes: 4 additions & 4 deletions test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)


def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
return torchaudio.info(filepath)


Expand Down Expand Up @@ -63,9 +63,9 @@ def test_info_wav(self, dtype, sample_rate, num_channels):
py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)

assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels()
assert py_info.sample_rate == ts_info.sample_rate
assert py_info.num_frames == ts_info.num_frames
assert py_info.num_channels == ts_info.num_channels

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand Down
12 changes: 10 additions & 2 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
)


class AudioMetaData:
    """Plain container for the signal properties of an audio file.

    The constructor stores its three arguments unchanged as the instance
    attributes ``sample_rate``, ``num_frames`` and ``num_channels``.
    """

    def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
        # Each argument is kept as-is under the attribute of the same name.
        self.num_channels = num_channels
        self.num_frames = num_frames
        self.sample_rate = sample_rate


@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
def info(filepath: str) -> AudioMetaData:
"""Get signal information of an audio file."""
return torch.ops.torchaudio.sox_io_get_info(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
17 changes: 6 additions & 11 deletions torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,10 @@
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
namespace {

////////////////////////////////////////////////////////////////////////////////
// typedefs.h
////////////////////////////////////////////////////////////////////////////////
// Registers SignalInfo through torch::class_ under the namespace "torchaudio"
// with the class name "SignalInfo". The registration exposes a constructor
// taking three int64_t values and the three getters get_sample_rate,
// get_num_channels and get_num_frames.
static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_frames", &SignalInfo::getNumFrames);

////////////////////////////////////////////////////////////////////////////////
// sox_utils.h
////////////////////////////////////////////////////////////////////////////////
Expand All @@ -32,6 +21,12 @@ static auto registerTensorSignal =
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
////////////////////////////////////////////////////////////////////////////////
// Registers sox_io::SignalInfo through torch::class_ under the namespace
// "torchaudio" with the class name "SignalInfo". Only the three getters are
// bound here; no constructor is registered in this expression.
static auto registerSignalInfo =
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);

static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
Expand Down
1 change: 0 additions & 1 deletion torchaudio/csrc/sox_effects.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#define TORCHAUDIO_SOX_EFFECTS_H

#include <torch/script.h>
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
namespace sox_effects {
Expand Down
24 changes: 22 additions & 2 deletions torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_io {

c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
// Construct a SignalInfo by copying the three arguments into the members
// sample_rate, num_channels and num_frames (in that order).
SignalInfo::SignalInfo(
    const int64_t sample_rate_,
    const int64_t num_channels_,
    const int64_t num_frames_)
    : sample_rate(sample_rate_),
      num_channels(num_channels_),
      num_frames(num_frames_) {}

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}

int64_t SignalInfo::getNumChannels() const {
return num_channels;
}

int64_t SignalInfo::getNumFrames() const {
return num_frames;
}

c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
Expand All @@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
throw std::runtime_error("Error opening audio file");
}

return c10::make_intrusive<torchaudio::SignalInfo>(
return c10::make_intrusive<SignalInfo>(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
Expand Down
17 changes: 15 additions & 2 deletions torchaudio/csrc/sox_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,25 @@

#include <torch/script.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
namespace sox_io {

c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path);
// Holds three int64_t properties of an audio signal and derives from
// torch::CustomClassHolder. The members are public and are also readable
// through the const getter methods declared below.
struct SignalInfo : torch::CustomClassHolder {
// Sample rate of the signal.
int64_t sample_rate;
// Number of channels in the signal.
int64_t num_channels;
// Number of frames in the signal.
int64_t num_frames;

// Takes the three values in the order: sample rate, channel count, frame
// count. Note the order when calling positionally.
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_);
// Read-only accessors for the corresponding members.
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
};

c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);

c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path,
Expand Down
23 changes: 0 additions & 23 deletions torchaudio/csrc/typedefs.cpp

This file was deleted.

23 changes: 0 additions & 23 deletions torchaudio/csrc/typedefs.h

This file was deleted.

29 changes: 0 additions & 29 deletions torchaudio/extension/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,9 @@ def _init_extension():
_init_script_module(ext)
else:
warnings.warn('torchaudio C++ extension is not available.')
_init_dummy_module()


def _init_script_module(module):
    """Load the library file behind ``module`` into torch.

    The import spec of ``module`` is looked up with
    ``importlib.util.find_spec`` and its ``origin`` path is handed first to
    ``torch.classes.load_library`` and then to ``torch.ops.load_library``.
    """
    spec = importlib.util.find_spec(module)
    library_path = spec.origin
    # Order matters for parity with the original: classes first, then ops.
    for registry in (torch.classes, torch.ops):
        registry.load_library(library_path)


def _init_dummy_module():
    """Install a pure-Python stand-in at ``torch.classes.torchaudio``.

    A one-field namedtuple (type name ``torchaudio``, field ``SignalInfo``)
    holding a locally defined ``SignalInfo`` class is assigned to the
    ``torchaudio`` attribute of ``torch.classes``.
    """

    class SignalInfo:
        """Python stand-in for the audio format information class.

        It exists so that sox_io backend functions can still be annotated,
        and torchaudio imported, when the torchaudio C++ extension is not
        available. Its interface must stay the same as the C++ equivalent.
        """

        def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
            self.num_frames = num_frames
            self.num_channels = num_channels
            self.sample_rate = sample_rate

        def get_num_frames(self):
            return self.num_frames

        def get_num_channels(self):
            return self.num_channels

        def get_sample_rate(self):
            return self.sample_rate

    namespace_type = namedtuple('torchaudio', ['SignalInfo'])
    torch.classes.torchaudio = namespace_type(SignalInfo=SignalInfo)