Skip to content

Commit a8528bf

Browse files
author
Ji Chen
committed
update format
1 parent 78d1efd commit a8528bf

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

test/test_models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ def test_mfcc(self):
3636
class TestMelResNet(common_utils.TorchaudioTestCase):
3737

3838
def test_waveform(self):
39+
"""
40+
Create a tensor as the input of _MelResNet layer
41+
and test if the output dimensions are correct.
42+
"""
43+
3944
batch_size = 2
4045
n_time = 200
4146
n_freq = 100
@@ -55,6 +60,10 @@ def test_waveform(self):
5560
class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
5661

5762
def test_waveform(self):
63+
"""
64+
Create a tensor as the input of _UpsampleNetwork block
65+
and test if the output dimensions are correct.
66+
"""
5867

5968
upsample_scales = [5, 5, 8]
6069
batch_size = 2
@@ -81,6 +90,10 @@ def test_waveform(self):
8190
class TestWaveRNN(common_utils.TorchaudioTestCase):
8291

8392
def test_waveform(self):
93+
"""
94+
Create a tensor as the input of _WaveRNN model
95+
and test if the output dimensions are correct.
96+
"""
8497

8598
upsample_scales = [5, 5, 8]
8699
n_rnn = 512

torchaudio/models/_wavernn.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88

99

1010
class _ResBlock(nn.Module):
11-
r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning
12-
for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016.
13-
It is a block used in WaveRNN.
11+
r"""ResNet block layer based on
12+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
1413
1514
Args:
1615
n_freq: the number of bins in a spectrogram (default=128)
@@ -47,7 +46,7 @@ def forward(self, x: Tensor) -> Tensor:
4746

4847

4948
class _MelResNet(nn.Module):
50-
r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN.
49+
r"""MelResNet layer based on a stack of ResBlocks.
5150
5251
Args:
5352
n_res_block: the number of ResBlock in stack (default=10)
@@ -71,10 +70,7 @@ def __init__(self,
7170
kernel_size: int = 5) -> None:
7271
super().__init__()
7372

74-
ResBlocks = []
75-
76-
for i in range(n_res_block):
77-
ResBlocks.append(_ResBlock(n_hidden))
73+
ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)]
7874

7975
self.melresnet_model = nn.Sequential(
8076
nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
@@ -98,7 +94,7 @@ def forward(self, x: Tensor) -> Tensor:
9894

9995

10096
class _Stretch2d(nn.Module):
101-
r"""This is a two-dimensional stretch layer. It is a block used in WaveRNN.
97+
r"""Two-dimensional stretch layer.
10298
10399
Args:
104100
x_scale: the scale factor in x axis
@@ -133,8 +129,7 @@ def forward(self, x: Tensor) -> Tensor:
133129

134130

135131
class _UpsampleNetwork(nn.Module):
136-
r"""This is an upsample block based on a stack of Conv2d and Strech2d layers.
137-
It is a block used in WaveRNN.
132+
r"""Upsample block based on a stack of Conv2d and Stretch2d layers.
138133
139134
Args:
140135
upsample_scales: the list of upsample scales
@@ -174,11 +169,9 @@ def __init__(self,
174169

175170
up_layers = []
176171
for scale in upsample_scales:
177-
k_size = (1, scale * 2 + 1)
178-
padding = (0, scale)
179172
stretch = _Stretch2d(scale, 1)
180-
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=k_size, padding=padding, bias=False)
181-
conv.weight.data.fill_(1. / k_size[1])
173+
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False)
174+
conv.weight.data.fill_(1. / (scale * 2 + 1))
182175
up_layers.append(stretch)
183176
up_layers.append(conv)
184177
self.upsample_layers = nn.Sequential(*up_layers)
@@ -207,7 +200,9 @@ def forward(self, x: Tensor) -> Tensor:
207200

208201

209202
class _WaveRNN(nn.Module):
210-
r"""
203+
r"""WaveRNN model based on
204+
`"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_
205+
211206
Args:
212207
upsample_scales: the list of upsample scales
213208
n_bits: the bits of output waveform
@@ -220,7 +215,7 @@ class _WaveRNN(nn.Module):
220215
n_freq: the number of bins in a spectrogram (default=128)
221216
n_hidden: the number of hidden dimensions (default=128)
222217
n_output: the number of output dimensions (default=128)
223-
mode: the type of input waveform (default='RAW')
218+
mode: the type of input waveform in ['RAW', 'MOL'] (default='RAW')
224219
225220
Examples::
226221
>>> upsamplenetwork = _WaveRNN(upsample_scales=[5,5,8],
@@ -262,6 +257,8 @@ def __init__(self,
262257
self.n_classes = 2 ** n_bits
263258
elif self.mode == 'MOL':
264259
self.n_classes = 30
260+
else:
261+
raise ValueError("Unknown input mode - {}".format(self.mode))
265262

266263
self.n_rnn = n_rnn
267264
self.n_aux = n_output // 4
@@ -294,8 +291,8 @@ def forward(self, x: Tensor, mels: Tensor) -> Tensor:
294291
"""
295292

296293
batch_size = x.size(0)
297-
h1 = torch.zeros(1, batch_size, self.n_rnn, device=x.device)
298-
h2 = torch.zeros(1, batch_size, self.n_rnn, device=x.device)
294+
h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=x.dtype, device=x.device)
295+
h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=x.dtype, device=x.device)
299296
mels, aux = self.upsample(mels)
300297

301298
aux_idx = [self.n_aux * i for i in range(5)]

0 commit comments

Comments
 (0)