Skip to content

Commit e994869

Browse files
committed
Get rid of typedefs/SignalInfo and replace AudioMetaData
1 parent e43ee19 commit e994869

File tree

10 files changed

+75
-115
lines changed

10 files changed

+75
-115
lines changed

test/sox_io_backend/test_info.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def test_wav(self, dtype, sample_rate, num_channels):
3333
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
3434
save_wav(path, data, sample_rate)
3535
info = sox_io_backend.info(path)
36-
assert info.get_sample_rate() == sample_rate
37-
assert info.get_num_frames() == sample_rate * duration
38-
assert info.get_num_channels() == num_channels
36+
assert info.sample_rate == sample_rate
37+
assert info.num_frames == sample_rate * duration
38+
assert info.num_channels == num_channels
3939

4040
@parameterized.expand(list(itertools.product(
4141
['float32', 'int32', 'int16', 'uint8'],
@@ -49,9 +49,9 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
4949
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
5050
save_wav(path, data, sample_rate)
5151
info = sox_io_backend.info(path)
52-
assert info.get_sample_rate() == sample_rate
53-
assert info.get_num_frames() == sample_rate * duration
54-
assert info.get_num_channels() == num_channels
52+
assert info.sample_rate == sample_rate
53+
assert info.num_frames == sample_rate * duration
54+
assert info.num_channels == num_channels
5555

5656
@parameterized.expand(list(itertools.product(
5757
[8000, 16000],
@@ -67,10 +67,10 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
6767
compression=bit_rate, duration=duration,
6868
)
6969
info = sox_io_backend.info(path)
70-
assert info.get_sample_rate() == sample_rate
70+
assert info.sample_rate == sample_rate
7171
# mp3 does not preserve the number of samples
72-
# assert info.get_num_frames() == sample_rate * duration
73-
assert info.get_num_channels() == num_channels
72+
# assert info.num_frames == sample_rate * duration
73+
assert info.num_channels == num_channels
7474

7575
@parameterized.expand(list(itertools.product(
7676
[8000, 16000],
@@ -86,9 +86,9 @@ def test_flac(self, sample_rate, num_channels, compression_level):
8686
compression=compression_level, duration=duration,
8787
)
8888
info = sox_io_backend.info(path)
89-
assert info.get_sample_rate() == sample_rate
90-
assert info.get_num_frames() == sample_rate * duration
91-
assert info.get_num_channels() == num_channels
89+
assert info.sample_rate == sample_rate
90+
assert info.num_frames == sample_rate * duration
91+
assert info.num_channels == num_channels
9292

9393
@parameterized.expand(list(itertools.product(
9494
[8000, 16000],
@@ -104,9 +104,9 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
104104
compression=quality_level, duration=duration,
105105
)
106106
info = sox_io_backend.info(path)
107-
assert info.get_sample_rate() == sample_rate
108-
assert info.get_num_frames() == sample_rate * duration
109-
assert info.get_num_channels() == num_channels
107+
assert info.sample_rate == sample_rate
108+
assert info.num_frames == sample_rate * duration
109+
assert info.num_channels == num_channels
110110

111111

112112
@skipIfNoExtension
@@ -120,6 +120,6 @@ def test_opus(self, bitrate, num_channels, compression_level):
120120
"""`sox_io_backend.info` can check opus file correcty"""
121121
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
122122
info = sox_io_backend.info(path)
123-
assert info.get_sample_rate() == 48000
124-
assert info.get_num_frames() == 32768
125-
assert info.get_num_channels() == num_channels
123+
assert info.sample_rate == 48000
124+
assert info.num_frames == 32768
125+
assert info.num_channels == num_channels

test/sox_io_backend/test_torchscript.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121

2222

23-
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
23+
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
2424
return torchaudio.info(filepath)
2525

2626

@@ -63,9 +63,9 @@ def test_info_wav(self, dtype, sample_rate, num_channels):
6363
py_info = py_info_func(audio_path)
6464
ts_info = ts_info_func(audio_path)
6565

66-
assert py_info.get_sample_rate() == ts_info.get_sample_rate()
67-
assert py_info.get_num_frames() == ts_info.get_num_frames()
68-
assert py_info.get_num_channels() == ts_info.get_num_channels()
66+
assert py_info.sample_rate == ts_info.sample_rate
67+
assert py_info.num_frames == ts_info.num_frames
68+
assert py_info.num_channels == ts_info.num_channels
6969

7070
@parameterized.expand(list(itertools.product(
7171
['float32', 'int32', 'int16', 'uint8'],

torchaudio/backend/sox_io_backend.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@
66
)
77

88

9+
class AudioMetaData:
10+
def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
11+
self.sample_rate = sample_rate
12+
self.num_frames = num_frames
13+
self.num_channels = num_channels
14+
15+
916
@_mod_utils.requires_module('torchaudio._torchaudio')
10-
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
17+
def info(filepath: str) -> AudioMetaData:
1118
"""Get signal information of an audio file."""
12-
return torch.ops.torchaudio.sox_io_get_info(filepath)
19+
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
20+
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
1321

1422

1523
@_mod_utils.requires_module('torchaudio._torchaudio')

torchaudio/csrc/register.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,10 @@
44
#include <torchaudio/csrc/sox_effects.h>
55
#include <torchaudio/csrc/sox_io.h>
66
#include <torchaudio/csrc/sox_utils.h>
7-
#include <torchaudio/csrc/typedefs.h>
87

98
namespace torchaudio {
109
namespace {
1110

12-
////////////////////////////////////////////////////////////////////////////////
13-
// typedefs.h
14-
////////////////////////////////////////////////////////////////////////////////
15-
static auto registerSignalInfo =
16-
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
17-
.def(torch::init<int64_t, int64_t, int64_t>())
18-
.def("get_sample_rate", &SignalInfo::getSampleRate)
19-
.def("get_num_channels", &SignalInfo::getNumChannels)
20-
.def("get_num_frames", &SignalInfo::getNumFrames);
21-
2211
////////////////////////////////////////////////////////////////////////////////
2312
// sox_utils.h
2413
////////////////////////////////////////////////////////////////////////////////
@@ -32,6 +21,12 @@ static auto registerTensorSignal =
3221
////////////////////////////////////////////////////////////////////////////////
3322
// sox_io.h
3423
////////////////////////////////////////////////////////////////////////////////
24+
static auto registerSignalInfo =
25+
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
26+
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
27+
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
28+
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
29+
3530
static auto registerGetInfo = torch::RegisterOperators().op(
3631
torch::RegisterOperators::options()
3732
.schema(

torchaudio/csrc/sox_effects.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define TORCHAUDIO_SOX_EFFECTS_H
33

44
#include <torch/script.h>
5-
#include <torchaudio/csrc/typedefs.h>
65

76
namespace torchaudio {
87
namespace sox_effects {

torchaudio/csrc/sox_io.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
88
namespace torchaudio {
99
namespace sox_io {
1010

11-
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
11+
SignalInfo::SignalInfo(
12+
const int64_t sample_rate_,
13+
const int64_t num_channels_,
14+
const int64_t num_frames_)
15+
: sample_rate(sample_rate_),
16+
num_channels(num_channels_),
17+
num_frames(num_frames_){};
18+
19+
int64_t SignalInfo::getSampleRate() const {
20+
return sample_rate;
21+
}
22+
23+
int64_t SignalInfo::getNumChannels() const {
24+
return num_channels;
25+
}
26+
27+
int64_t SignalInfo::getNumFrames() const {
28+
return num_frames;
29+
}
30+
31+
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
1232
SoxFormat sf(sox_open_read(
1333
path.c_str(),
1434
/*signal=*/nullptr,
@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
1939
throw std::runtime_error("Error opening audio file");
2040
}
2141

22-
return c10::make_intrusive<torchaudio::SignalInfo>(
42+
return c10::make_intrusive<SignalInfo>(
2343
static_cast<int64_t>(sf->signal.rate),
2444
static_cast<int64_t>(sf->signal.channels),
2545
static_cast<int64_t>(sf->signal.length / sf->signal.channels));

torchaudio/csrc/sox_io.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,25 @@
33

44
#include <torch/script.h>
55
#include <torchaudio/csrc/sox_utils.h>
6-
#include <torchaudio/csrc/typedefs.h>
76

87
namespace torchaudio {
98
namespace sox_io {
109

11-
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path);
10+
struct SignalInfo : torch::CustomClassHolder {
11+
int64_t sample_rate;
12+
int64_t num_channels;
13+
int64_t num_frames;
14+
15+
SignalInfo(
16+
const int64_t sample_rate_,
17+
const int64_t num_channels_,
18+
const int64_t num_frames_);
19+
int64_t getSampleRate() const;
20+
int64_t getNumChannels() const;
21+
int64_t getNumFrames() const;
22+
};
23+
24+
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
1225

1326
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
1427
const std::string& path,

torchaudio/csrc/typedefs.cpp

Lines changed: 0 additions & 23 deletions
This file was deleted.

torchaudio/csrc/typedefs.h

Lines changed: 0 additions & 23 deletions
This file was deleted.

torchaudio/extension/extension.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,9 @@ def _init_extension():
1212
_init_script_module(ext)
1313
else:
1414
warnings.warn('torchaudio C++ extension is not available.')
15-
_init_dummy_module()
1615

1716

1817
def _init_script_module(module):
1918
path = importlib.util.find_spec(module).origin
2019
torch.classes.load_library(path)
2120
torch.ops.load_library(path)
22-
23-
24-
def _init_dummy_module():
25-
class SignalInfo:
26-
"""Data class for audio format information
27-
28-
Used when torchaudio C++ extension is not available for annotating
29-
sox_io backend functions so that torchaudio is still importable
30-
without extension.
31-
This class has to implement the same interface as C++ equivalent.
32-
"""
33-
def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
34-
self.sample_rate = sample_rate
35-
self.num_channels = num_channels
36-
self.num_frames = num_frames
37-
38-
def get_sample_rate(self):
39-
return self.sample_rate
40-
41-
def get_num_channels(self):
42-
return self.num_channels
43-
44-
def get_num_frames(self):
45-
return self.num_frames
46-
47-
DummyModule = namedtuple('torchaudio', ['SignalInfo'])
48-
module = DummyModule(SignalInfo)
49-
setattr(torch.classes, 'torchaudio', module)

0 commit comments

Comments
 (0)