88
99
1010class _ResBlock (nn .Module ):
11- r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning
12- for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016.
13- It is a block used in WaveRNN.
11+ r"""ResNet block layer based on
12+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
1413
1514 Args:
1615 n_freq: the number of bins in a spectrogram (default=128)
@@ -47,7 +46,7 @@ def forward(self, x: Tensor) -> Tensor:
4746
4847
4948class _MelResNet (nn .Module ):
50- r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN .
49+ r"""MelResNet layer based on a stack of ResBlocks.
5150
5251 Args:
5352 n_res_block: the number of ResBlock in stack (default=10)
@@ -71,10 +70,7 @@ def __init__(self,
7170 kernel_size : int = 5 ) -> None :
7271 super ().__init__ ()
7372
74- ResBlocks = []
75-
76- for i in range (n_res_block ):
77- ResBlocks .append (_ResBlock (n_hidden ))
73+ ResBlocks = [_ResBlock (n_hidden ) for _ in range (n_res_block )]
7874
7975 self .melresnet_model = nn .Sequential (
8076 nn .Conv1d (in_channels = n_freq , out_channels = n_hidden , kernel_size = kernel_size , bias = False ),
@@ -98,7 +94,7 @@ def forward(self, x: Tensor) -> Tensor:
9894
9995
10096class _Stretch2d (nn .Module ):
101- r"""This is a two -dimensional stretch layer. It is a block used in WaveRNN .
97+ r"""Two-dimensional stretch layer.
10298
10399 Args:
104100 x_scale: the scale factor in x axis
@@ -133,8 +129,7 @@ def forward(self, x: Tensor) -> Tensor:
133129
134130
135131class _UpsampleNetwork (nn .Module ):
136- r"""This is an upsample block based on a stack of Conv2d and Strech2d layers.
137- It is a block used in WaveRNN.
132+ r"""Upsample block based on a stack of Conv2d and Stretch2d layers.
138133
139134 Args:
140135 upsample_scales: the list of upsample scales
@@ -174,11 +169,9 @@ def __init__(self,
174169
175170 up_layers = []
176171 for scale in upsample_scales :
177- k_size = (1 , scale * 2 + 1 )
178- padding = (0 , scale )
179172 stretch = _Stretch2d (scale , 1 )
180- conv = nn .Conv2d (in_channels = 1 , out_channels = 1 , kernel_size = k_size , padding = padding , bias = False )
181- conv .weight .data .fill_ (1. / k_size [ 1 ] )
173+ conv = nn .Conv2d (in_channels = 1 , out_channels = 1 , kernel_size = ( 1 , scale * 2 + 1 ), padding = ( 0 , scale ) , bias = False )
174+ conv .weight .data .fill_ (1. / ( scale * 2 + 1 ) )
182175 up_layers .append (stretch )
183176 up_layers .append (conv )
184177 self .upsample_layers = nn .Sequential (* up_layers )
@@ -207,7 +200,9 @@ def forward(self, x: Tensor) -> Tensor:
207200
208201
209202class _WaveRNN (nn .Module ):
210- r"""
203+ r"""WaveRNN model based on
204+ `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_
205+
211206 Args:
212207 upsample_scales: the list of upsample scales
213208 n_bits: the bits of output waveform
@@ -220,7 +215,7 @@ class _WaveRNN(nn.Module):
220215 n_freq: the number of bins in a spectrogram (default=128)
221216 n_hidden: the number of hidden dimensions (default=128)
222217 n_output: the number of output dimensions (default=128)
223- mode: the type of input waveform (default='RAW')
218+ mode: the type of input waveform in ['RAW', 'MOL'] (default='RAW')
224219
225220 Examples::
226221 >>> upsamplenetwork = _WaveRNN(upsample_scales=[5,5,8],
@@ -262,6 +257,8 @@ def __init__(self,
262257 self .n_classes = 2 ** n_bits
263258 elif self .mode == 'MOL' :
264259 self .n_classes = 30
260+ else :
261+ raise ValueError ("Unknown input mode - {}" .format (self .mode ))
265262
266263 self .n_rnn = n_rnn
267264 self .n_aux = n_output // 4
@@ -294,8 +291,8 @@ def forward(self, x: Tensor, mels: Tensor) -> Tensor:
294291 """
295292
296293 batch_size = x .size (0 )
297- h1 = torch .zeros (1 , batch_size , self .n_rnn , device = x .device )
298- h2 = torch .zeros (1 , batch_size , self .n_rnn , device = x .device )
294+ h1 = torch .zeros (1 , batch_size , self .n_rnn , dtype = x . dtype , device = x .device )
295+ h2 = torch .zeros (1 , batch_size , self .n_rnn , dtype = x . dtype , device = x .device )
299296 mels , aux = self .upsample (mels )
300297
301298 aux_idx = [self .n_aux * i for i in range (5 )]
0 commit comments