Merged

20 commits
fe29555
Merge pull request #1 from pytorch/master
yoyolicoris Feb 19, 2021
f695635
Merge pull request #4 from pytorch/master
yoyolicoris Feb 26, 2021
68f551b
Merge pull request #9 from pytorch/master
yoyolicoris Mar 16, 2021
537a153
Merge pull request #10 from pytorch/master
yoyolicoris Apr 1, 2021
86fcc3d
Merge branch 'master' of https://github.com/yoyololicon/audio
yoyolicoris Apr 3, 2021
e0d18a0
refactor lfilter cpp implementation
yoyolicoris Apr 6, 2021
3b17e40
add comments
yoyolicoris Apr 7, 2021
2bad0fe
feat: normalize a_coeff outside custom function
yoyolicoris Apr 7, 2021
38d1b4f
test: add gradgradcheck
yoyolicoris Apr 7, 2021
92fe68f
refactor: remove iir_autograd
yoyolicoris Apr 8, 2021
01b73e1
Merge branch 'master' of https://github.com/yoyololicon/audio
yoyolicoris Apr 8, 2021
485edb0
Merge branch 'master' into lfilter-higher-order-gradient
yoyolicoris Apr 8, 2021
20cc31d
test: add 2nd order gradient check
yoyolicoris Apr 8, 2021
1ebaa02
test: use shorter sequence for faster runtime
yoyolicoris Apr 8, 2021
bf2df52
test: use even shorter inputs
yoyolicoris Apr 8, 2021
dd4a3c0
Merge branch 'master' into lfilter-higher-order-gradient
yoyolicoris Apr 8, 2021
bb1a792
feat: add custom FIR for performance reason
yoyolicoris May 3, 2021
e490090
remove unnecessary autograd guard
yoyolicoris May 4, 2021
1bc3f45
test: bring back b_coeff test
yoyolicoris May 5, 2021
6b1536f
Merge remote-tracking branch 'upstream/master' into lfilter-higher-or…
yoyolicoris May 5, 2021
3 changes: 2 additions & 1 deletion test/torchaudio_unittest/functional/autograd_impl.py
@@ -3,7 +3,7 @@
from parameterized import parameterized
from torch import Tensor
import torchaudio.functional as F
from torch.autograd import gradcheck
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
@@ -26,6 +26,7 @@ def assert_grad(
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_)

def test_lfilter_x(self):
torch.random.manual_seed(2434)
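The test change is small but load-bearing: gradgradcheck verifies second-order gradients, which only pass once the backward pass itself is built from differentiable operations. A minimal standalone sketch of the same check (not the suite's exact fixture; the short double-precision signal and the clamp=False flag are choices made here to keep the numerical Jacobians tractable):

import torch
import torchaudio.functional as F
from torch.autograd import gradcheck, gradgradcheck

torch.random.manual_seed(2434)
# double precision and a short signal keep the numerical checks fast and stable
x = torch.randn(1, 32, dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, -0.5], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.6, 0.4], dtype=torch.float64, requires_grad=True)

def fn(x, a, b):
    return F.lfilter(x, a, b, clamp=False)

assert gradcheck(fn, (x, a, b))      # first-order gradients
assert gradgradcheck(fn, (x, a, b))  # second-order gradients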
231 changes: 108 additions & 123 deletions torchaudio/csrc/lfilter.cpp
@@ -80,170 +80,159 @@ void lfilter_core_generic_loop(
}
}

std::tuple<at::Tensor, at::Tensor> lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0);
int64_t n_sample = waveform.size(1);
int64_t n_order = a_coeffs_normalized.size(0);
int64_t n_sample_padded = n_sample + n_order - 1;

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();

auto device = waveform.device();
int64_t n_order = a_coeffs.size(0);
auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_channel, n_sample_padded}, options);

if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
} else {
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
}

TORCH_INTERNAL_ASSERT(n_order > 0);
auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});

namespace F = torch::nn::functional;
ctx->save_for_backward({waveform, a_coeffs_normalized, output});
return output;
}

auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
auto padded_output_waveform = torch::zeros_like(padded_waveform);
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto a_coeffs_normalized = saved[1];
auto y = saved[2];

auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
int64_t n_channel = x.size(0);
int64_t n_order = a_coeffs_normalized.size(0);

auto input_signal_windows =
F::conv1d(
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
.squeeze(1);
auto dx = torch::Tensor();
auto da = torch::Tensor();
auto dy = grad_outputs[0];

input_signal_windows.div_(a_coeffs[0]);
a_coeff_flipped.div_(a_coeffs[0]);
namespace F = torch::nn::functional;

if (device.is_cpu()) {
cpu_lfilter_core_loop(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
} else {
lfilter_core_generic_loop(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
}
if (a_coeffs_normalized.requires_grad()) {
auto dyda = F::pad(
DifferentiableIIR::apply(-y, a_coeffs_normalized),
F::PadFuncOptions({n_order - 1, 0}));

auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});
da = F::conv1d(
dyda.unsqueeze(0),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
}

return {output, input_signal_windows};
}
if (x.requires_grad()) {
dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
}

torch::Tensor lfilter_simple(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
return std::get<0>(lfilter_core(waveform, a_coeffs, b_coeffs));
}
return {dx, da};
}
};
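The trick in DifferentiableIIR::backward is that both gradients are themselves IIR filtering problems, so the function calls itself through apply, which is what makes the second-order graph available. For y[n] = x[n] - sum_{k>=1} a[k] y[n-k], the gradient with respect to x is the same all-pole filter run over the time-reversed incoming gradient, and the gradient with respect to a[k] is a correlation of the incoming gradient with the filter applied to -y. A pure-Python sketch of the same structure, with a naive loop standing in for cpu_lfilter_core_loop (slow, but it makes the recursion explicit):

import torch
import torch.nn.functional as F

def allpole(x, a):
    # y[n] = x[n] - sum_{k >= 1} a[k] * y[n - k], assuming a[0] == 1
    n_order = a.size(0)
    y = torch.zeros_like(x)
    for n in range(x.size(1)):
        acc = x[:, n]
        for k in range(1, min(n_order, n + 1)):
            acc = acc - a[k] * y[:, n - k]
        y[:, n] = acc
    return y

class DifferentiableIIR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, a_coeffs_normalized):
        output = allpole(waveform, a_coeffs_normalized)
        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
        return output

    @staticmethod
    def backward(ctx, dy):
        x, a, y = ctx.saved_tensors
        n_order = a.size(0)
        dx = da = None
        if ctx.needs_input_grad[1]:
            # dy[n]/da[k] obeys the same recursion driven by -y, delayed by k
            dyda = F.pad(DifferentiableIIR.apply(-y, a), (n_order - 1, 0))
            da = torch.stack([
                (dyda[:, n_order - 1 - k:dyda.size(1) - k] * dy).sum()
                for k in range(n_order)
            ])
        if ctx.needs_input_grad[0]:
            # gradient w.r.t. the input: same filter, reversed in time
            dx = DifferentiableIIR.apply(dy.flip(1), a).flip(1)
        return dx, da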

class DifferentiableLfilter
: public torch::autograd::Function<DifferentiableLfilter> {
class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
at::AutoNonVariableTypeMode g;
auto result = lfilter_core(waveform, a_coeffs, b_coeffs);
ctx->save_for_backward(
{waveform,
a_coeffs,
b_coeffs,
std::get<0>(result),
std::get<1>(result)});
return std::get<0>(result);
int64_t n_order = b_coeffs.size(0);

namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

auto output =
F::conv1d(
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
.squeeze(1);

ctx->save_for_backward({waveform, b_coeffs, output});
return output;
}

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto waveform = saved[0];
auto a_coeffs = saved[1];
auto b_coeffs = saved[2];
auto y = saved[3];
auto xh = saved[4];

auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0);
int64_t n_sample = waveform.size(1);
int64_t n_order = a_coeffs.size(0);
int64_t n_sample_padded = n_sample + n_order - 1;
auto x = saved[0];
auto b_coeffs = saved[1];
auto y = saved[2];

auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
b_coeff_flipped.div_(a_coeffs[0]);
a_coeff_flipped.div_(a_coeffs[0]);
int64_t n_channel = x.size(0);
int64_t n_order = b_coeffs.size(0);

auto dx = torch::Tensor();
auto da = torch::Tensor();
auto db = torch::Tensor();
auto dy = grad_outputs[0];

at::AutoNonVariableTypeMode g;
namespace F = torch::nn::functional;
auto options = torch::TensorOptions().dtype(dtype).device(device);

if (a_coeffs.requires_grad()) {
auto dyda = torch::zeros({n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(-y, a_coeff_flipped, dyda);
} else {
lfilter_core_generic_loop(-y, a_coeff_flipped, dyda);
}

da = F::conv1d(
dyda.unsqueeze(0),
if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
da.div_(a_coeffs[0]);
}

if (b_coeffs.requires_grad() || waveform.requires_grad()) {
auto dxh = torch::zeros({n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(dy.flip(1), a_coeff_flipped, dxh);
} else {
lfilter_core_generic_loop(dy.flip(1), a_coeff_flipped, dxh);
}

dxh = dxh.index(
{torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)})
.flip(1);

if (waveform.requires_grad()) {
dx = F::conv1d(
F::pad(dxh.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order}))
.squeeze(1);
dx.div_(a_coeffs[0]);
}
if (b_coeffs.requires_grad()) {
db =
F::conv1d(
F::pad(
waveform.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
dxh.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
db.div_(a_coeffs[0]);
}
if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order}))
.squeeze(1);
}

return {dx, da, db};
return {dx, db};
}
};
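DifferentiableFIR is the moving-average half. Forward is a causal convolution with the flipped taps; in backward, the gradient for the taps is a per-channel (grouped) correlation of the padded input with the incoming gradient, and the gradient for the input is a correlation of the incoming gradient with the un-flipped taps. A Python sketch of the same layout, using only standard differentiable ops so double backward works through it:

import torch
import torch.nn.functional as F

class DifferentiableFIR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, b_coeffs):
        n_order = b_coeffs.size(0)
        # causal FIR: y[n] = sum_k b[k] * x[n - k]
        output = F.conv1d(
            F.pad(waveform.unsqueeze(1), (n_order - 1, 0)),
            b_coeffs.flip(0).view(1, 1, n_order),
        ).squeeze(1)
        ctx.save_for_backward(waveform, b_coeffs)
        return output

    @staticmethod
    def backward(ctx, dy):
        x, b = ctx.saved_tensors
        n_order = b.size(0)
        n_channel = x.size(0)
        dx = db = None
        if ctx.needs_input_grad[1]:
            # db[k] = sum_{c, n} dy[c, n] * x[c, n - k]
            db = F.conv1d(
                F.pad(x.unsqueeze(0), (n_order - 1, 0)),
                dy.unsqueeze(1),
                groups=n_channel,
            ).sum(1).squeeze(0).flip(0)
        if ctx.needs_input_grad[0]:
            # dx[c, n] = sum_k b[k] * dy[c, n + k]
            dx = F.conv1d(
                F.pad(dy.unsqueeze(1), (0, n_order - 1)),
                b.view(1, 1, n_order),
            ).squeeze(1)
        return dx, db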

torch::Tensor lfilter_autograd(
torch::Tensor lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
return DifferentiableLfilter::apply(waveform, a_coeffs, b_coeffs);
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);

int64_t n_order = b_coeffs.size(0);

TORCH_INTERNAL_ASSERT(n_order > 0);

auto filtered_waveform =
DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);

auto output =
DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
return output;
}
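lfilter_core then chains the two pieces, dividing both coefficient vectors by a_coeffs[0] before they enter the custom functions (the "normalize a_coeff outside custom function" commit). Because the division happens in ordinary autograd territory, the gradient with respect to a[0] falls out of the quotient rule instead of needing special handling inside either backward. The equivalent glue for the Python sketches above would be:

def lfilter(waveform, a_coeffs, b_coeffs):
    # normalization stays outside the custom Functions, so autograd
    # differentiates through the division by a_coeffs[0] on its own
    filtered = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[0])
    return DifferentiableIIR.apply(filtered, a_coeffs / a_coeffs[0])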

} // namespace
@@ -259,10 +248,6 @@ TORCH_LIBRARY(torchaudio, m) {
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchaudio, DefaultBackend, m) {
m.impl("torchaudio::_lfilter", lfilter_simple);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
m.impl("torchaudio::_lfilter", lfilter_autograd);
TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
m.impl("torchaudio::_lfilter", lfilter_core);
}
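With the operator now built entirely from differentiable pieces, a single CompositeImplicitAutograd registration replaces the old DefaultBackend/Autograd pair; there is no longer a separate non-autograd kernel to dispatch to. A quick way to exercise the result (a sketch, assuming torchaudio's extension is loaded so the schema above is registered):

import torch
import torchaudio  # noqa: F401  (loads the torchaudio::_lfilter op)

x = torch.randn(1, 16, dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, -0.2], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.4, 0.4], dtype=torch.float64, requires_grad=True)

y = torch.ops.torchaudio._lfilter(x, a, b)
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
g.sum().backward()  # second order: only works with a differentiable backward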