Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import extension
from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
Expand Down
18 changes: 18 additions & 0 deletions torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef TORCHAUDIO_REGISTER_H
#define TORCHAUDIO_REGISTER_H

#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
namespace {

static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);

} // namespace
} // namespace torchaudio
#endif
23 changes: 23 additions & 0 deletions torchaudio/csrc/typedefs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}

int64_t SignalInfo::getNumChannels() const {
return num_channels;
}

int64_t SignalInfo::getNumSamples() const {
return num_samples;
}
} // namespace torchaudio
23 changes: 23 additions & 0 deletions torchaudio/csrc/typedefs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H

#include <torch/script.h>

namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
};

} // namespace torchaudio

#endif
7 changes: 7 additions & 0 deletions torchaudio/extension/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .extension import (
_init_extension,
)

_init_extension()

del _init_extension
49 changes: 49 additions & 0 deletions torchaudio/extension/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import warnings
import importlib
from collections import namedtuple

import torch
from torchaudio._internal import module_utils as _mod_utils


def _init_extension():
ext = 'torchaudio._torchaudio'
if _mod_utils.is_module_available(ext):
_init_script_module(ext)
else:
warnings.warn('torchaudio C++ extension is not available.')
_init_dummy_module()


def _init_script_module(module):
path = importlib.util.find_spec(module).origin
torch.classes.load_library(path)
torch.ops.load_library(path)


def _init_dummy_module():
class SignalInfo:
"""Data class for audio format information

Used when torchaudio C++ extension is not available for annotating
sox_io backend functions so that torchaudio is still importable
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples

def get_sample_rate(self):
return self.sample_rate

def get_num_channels(self):
return self.num_channels

def get_num_samples(self):
return self.num_samples

DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
setattr(torch.classes, 'torchaudio', module)