Skip to content

Commit 878d3da

Browse files
jimchen90 and Ji Chen authored
Update MelResNet (#751)
* update variable names and docstring * update format * update docstring and output value Co-authored-by: Ji Chen <[email protected]>
1 parent 4daf2fb commit 878d3da

File tree

2 files changed

+56
-72
lines changed

2 files changed

+56
-72
lines changed

test/test_models.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,20 @@ def test_mfcc(self):
3636
class TestMelResNet(common_utils.TorchaudioTestCase):
3737

3838
def test_waveform(self):
39+
"""Validate the output dimensions of a _MelResNet block.
40+
"""
3941

40-
batch_size = 2
41-
num_features = 200
42-
input_dims = 100
43-
output_dims = 128
44-
res_blocks = 10
45-
hidden_dims = 128
46-
pad = 2
42+
n_batch = 2
43+
n_time = 200
44+
n_freq = 100
45+
n_output = 128
46+
n_res_block = 10
47+
n_hidden = 128
48+
kernel_size = 5
4749

48-
model = _MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad)
50+
model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
4951

50-
x = torch.rand(batch_size, input_dims, num_features)
52+
x = torch.rand(n_batch, n_freq, n_time)
5153
out = model(x)
5254

53-
assert out.size() == (batch_size, output_dims, num_features - pad * 2)
55+
assert out.size() == (n_batch, n_output, n_time - kernel_size + 1)

torchaudio/models/_wavernn.py

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,101 +5,83 @@
55

66

77
class _ResBlock(nn.Module):
8-
r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning
9-
for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016.
10-
It is a block used in WaveRNN. WaveRNN is based on the paper "Efficient Neural Audio Synthesis".
11-
Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart,
12-
Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
8+
r"""ResNet block based on "Deep Residual Learning for Image Recognition"
9+
10+
The paper link is https://arxiv.org/pdf/1512.03385.pdf.
1311
1412
Args:
15-
num_dims: the number of compute dimensions in the input (default=128).
13+
n_freq: the number of bins in a spectrogram (default=128)
1614
17-
Examples::
18-
>>> resblock = _ResBlock(num_dims=128)
19-
>>> input = torch.rand(10, 128, 512)
20-
>>> output = resblock(input)
15+
Examples
16+
>>> resblock = _ResBlock()
17+
>>> input = torch.rand(10, 128, 512) # a random spectrogram
18+
>>> output = resblock(input) # shape: (10, 128, 512)
2119
"""
2220

23-
def __init__(self, num_dims: int = 128) -> None:
21+
def __init__(self, n_freq: int = 128) -> None:
2422
super().__init__()
2523

2624
self.resblock_model = nn.Sequential(
27-
nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
28-
nn.BatchNorm1d(num_dims),
25+
nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
26+
nn.BatchNorm1d(n_freq),
2927
nn.ReLU(inplace=True),
30-
nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
31-
nn.BatchNorm1d(num_dims)
28+
nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
29+
nn.BatchNorm1d(n_freq)
3230
)
3331

34-
def forward(self, x: Tensor) -> Tensor:
32+
def forward(self, specgram: Tensor) -> Tensor:
3533
r"""Pass the input through the _ResBlock layer.
36-
3734
Args:
38-
x: the input sequence to the _ResBlock layer (required).
35+
specgram (Tensor): the input sequence to the _ResBlock layer (n_batch, n_freq, n_time).
3936
40-
Shape:
41-
- x: :math:`(N, S, T)`.
42-
- output: :math:`(N, S, T)`.
43-
where N is the batch size, S is the number of input sequence,
44-
T is the length of input sequence.
37+
Return:
38+
Tensor shape: (n_batch, n_freq, n_time)
4539
"""
4640

47-
residual = x
48-
return self.resblock_model(x) + residual
41+
return self.resblock_model(specgram) + specgram
4942

5043

5144
class _MelResNet(nn.Module):
52-
r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN.
53-
WaveRNN is based on the paper "Efficient Neural Audio Synthesis". Nal Kalchbrenner, Erich Elsen,
54-
Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart, Florian Stimberg, Aaron van den Oord,
55-
Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
45+
r"""MelResNet layer uses a stack of ResBlocks on spectrogram.
5646
5747
Args:
58-
res_blocks: the number of ResBlock in stack (default=10).
59-
input_dims: the number of input sequence (default=100).
60-
hidden_dims: the number of compute dimensions (default=128).
61-
output_dims: the number of output sequence (default=128).
62-
pad: the number of kernal size (pad * 2 + 1) in the first Conv1d layer (default=2).
63-
64-
Examples::
65-
>>> melresnet = _MelResNet(res_blocks=10, input_dims=100,
66-
hidden_dims=128, output_dims=128, pad=2)
67-
>>> input = torch.rand(10, 100, 512)
68-
>>> output = melresnet(input)
48+
n_res_block: the number of ResBlock in stack (default=10)
49+
n_freq: the number of bins in a spectrogram (default=128)
50+
n_hidden: the number of hidden dimensions (default=128)
51+
n_output: the number of output dimensions (default=128)
52+
kernel_size: the number of kernel size in the first Conv1d layer (default=5)
53+
54+
Examples
55+
>>> melresnet = _MelResNet()
56+
>>> input = torch.rand(10, 128, 512) # a random spectrogram
57+
>>> output = melresnet(input) # shape: (10, 128, 508)
6958
"""
7059

71-
def __init__(self, res_blocks: int = 10,
72-
input_dims: int = 100,
73-
hidden_dims: int = 128,
74-
output_dims: int = 128,
75-
pad: int = 2) -> None:
60+
def __init__(self,
61+
n_res_block: int = 10,
62+
n_freq: int = 128,
63+
n_hidden: int = 128,
64+
n_output: int = 128,
65+
kernel_size: int = 5) -> None:
7666
super().__init__()
7767

78-
kernel_size = pad * 2 + 1
79-
ResBlocks = []
80-
81-
for i in range(res_blocks):
82-
ResBlocks.append(_ResBlock(hidden_dims))
68+
ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)]
8369

8470
self.melresnet_model = nn.Sequential(
85-
nn.Conv1d(in_channels=input_dims, out_channels=hidden_dims, kernel_size=kernel_size, bias=False),
86-
nn.BatchNorm1d(hidden_dims),
71+
nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
72+
nn.BatchNorm1d(n_hidden),
8773
nn.ReLU(inplace=True),
8874
*ResBlocks,
89-
nn.Conv1d(in_channels=hidden_dims, out_channels=output_dims, kernel_size=1)
75+
nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1)
9076
)
9177

92-
def forward(self, x: Tensor) -> Tensor:
78+
def forward(self, specgram: Tensor) -> Tensor:
9379
r"""Pass the input through the _MelResNet layer.
94-
9580
Args:
96-
x: the input sequence to the _MelResNet layer (required).
81+
specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time).
9782
98-
Shape:
99-
- x: :math:`(N, S, T)`.
100-
- output: :math:`(N, P, T - 2 * pad)`.
101-
where N is the batch size, S is the number of input sequence,
102-
P is the number of output sequence, T is the length of input sequence.
83+
Return:
84+
Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
10385
"""
10486

105-
return self.melresnet_model(x)
87+
return self.melresnet_model(specgram)

0 commit comments

Comments
 (0)