Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,20 @@ def test_mfcc(self):
class TestMelResNet(common_utils.TorchaudioTestCase):

def test_waveform(self):
"""Validate the output dimensions of a _MelResNet block.
"""

batch_size = 2
num_features = 200
input_dims = 100
output_dims = 128
res_blocks = 10
hidden_dims = 128
pad = 2
n_batch = 2
n_time = 200
n_freq = 100
n_output = 128
n_res_block = 10
n_hidden = 128
kernel_size = 5

model = _MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad)
model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

x = torch.rand(batch_size, input_dims, num_features)
x = torch.rand(n_batch, n_freq, n_time)
out = model(x)

assert out.size() == (batch_size, output_dims, num_features - pad * 2)
assert out.size() == (n_batch, n_output, n_time - kernel_size + 1)
106 changes: 44 additions & 62 deletions torchaudio/models/_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,83 @@


class _ResBlock(nn.Module):
r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning
for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016.
It is a block used in WaveRNN. WaveRNN is based on the paper "Efficient Neural Audio Synthesis".
Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart,
Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
r"""ResNet block based on "Deep Residual Learning for Image Recognition"
The paper link is https://arxiv.org/pdf/1512.03385.pdf.
Args:
num_dims: the number of compute dimensions in the input (default=128).
n_freq: the number of bins in a spectrogram (default=128)
Examples::
>>> resblock = _ResBlock(num_dims=128)
>>> input = torch.rand(10, 128, 512)
>>> output = resblock(input)
Examples
>>> resblock = _ResBlock()
>>> input = torch.rand(10, 128, 512) # a random spectrogram
>>> output = resblock(input) # shape: (10, 128, 512)
"""

def __init__(self, num_dims: int = 128) -> None:
def __init__(self, n_freq: int = 128) -> None:
super().__init__()

self.resblock_model = nn.Sequential(
nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
nn.BatchNorm1d(num_dims),
nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
nn.BatchNorm1d(n_freq),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
nn.BatchNorm1d(num_dims)
nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
nn.BatchNorm1d(n_freq)
)

def forward(self, x: Tensor) -> Tensor:
def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _ResBlock layer.
Args:
x: the input sequence to the _ResBlock layer (required).
specgram (Tensor): the input sequence to the _ResBlock layer (n_batch, n_freq, n_time).
Shape:
- x: :math:`(N, S, T)`.
- output: :math:`(N, S, T)`.
where N is the batch size, S is the number of input sequence,
T is the length of input sequence.
Return:
Tensor shape: (n_batch, n_freq, n_time)
"""

residual = x
return self.resblock_model(x) + residual
return self.resblock_model(specgram) + specgram


class _MelResNet(nn.Module):
r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN.
WaveRNN is based on the paper "Efficient Neural Audio Synthesis". Nal Kalchbrenner, Erich Elsen,
Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart, Florian Stimberg, Aaron van den Oord,
Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
r"""MelResNet layer uses a stack of ResBlocks on spectrogram.
Args:
res_blocks: the number of ResBlock in stack (default=10).
input_dims: the number of input sequence (default=100).
hidden_dims: the number of compute dimensions (default=128).
output_dims: the number of output sequence (default=128).
pad: the number of kernal size (pad * 2 + 1) in the first Conv1d layer (default=2).
Examples::
>>> melresnet = _MelResNet(res_blocks=10, input_dims=100,
hidden_dims=128, output_dims=128, pad=2)
>>> input = torch.rand(10, 100, 512)
>>> output = melresnet(input)
n_res_block: the number of ResBlock in stack (default=10)
n_freq: the number of bins in a spectrogram (default=128)
n_hidden: the number of hidden dimensions (default=128)
n_output: the number of output dimensions (default=128)
kernel_size: the number of kernel size in the first Conv1d layer (default=5)
Examples
>>> melresnet = _MelResNet()
>>> input = torch.rand(10, 128, 512) # a random spectrogram
>>> output = melresnet(input) # shape: (10, 128, 508)
"""

def __init__(self, res_blocks: int = 10,
input_dims: int = 100,
hidden_dims: int = 128,
output_dims: int = 128,
pad: int = 2) -> None:
def __init__(self,
n_res_block: int = 10,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
kernel_size: int = 5) -> None:
super().__init__()

kernel_size = pad * 2 + 1
ResBlocks = []

for i in range(res_blocks):
ResBlocks.append(_ResBlock(hidden_dims))
ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)]

self.melresnet_model = nn.Sequential(
nn.Conv1d(in_channels=input_dims, out_channels=hidden_dims, kernel_size=kernel_size, bias=False),
nn.BatchNorm1d(hidden_dims),
nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
nn.BatchNorm1d(n_hidden),
nn.ReLU(inplace=True),
*ResBlocks,
nn.Conv1d(in_channels=hidden_dims, out_channels=output_dims, kernel_size=1)
nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1)
)

def forward(self, x: Tensor) -> Tensor:
def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _MelResNet layer.
Args:
x: the input sequence to the _MelResNet layer (required).
specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time).
Shape:
- x: :math:`(N, S, T)`.
- output: :math:`(N, P, T - 2 * pad)`.
where N is the batch size, S is the number of input sequence,
P is the number of output sequence, T is the length of input sequence.
Return:
Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
"""

return self.melresnet_model(x)
return self.melresnet_model(specgram)