@@ -198,8 +198,8 @@ def forward(self, specgram: Tensor) -> Tensor:
 class _WaveRNN(nn.Module):
     r"""WaveRNN model based on "Efficient Neural Audio Synthesis".
 
-    The paper link is https://arxiv.org/pdf/1802.08435.pdf. The input channels of waveform
-    and spectrogram have to be 1. The product of upsample_scales must equal hop_length.
+    The paper link is `<https://arxiv.org/pdf/1802.08435.pdf>`_. The input channels of waveform
+    and spectrogram have to be 1. The product of `upsample_scales` must equal `hop_length`.
 
     Args:
         upsample_scales: the list of upsample scales
@@ -217,11 +217,12 @@ class _WaveRNN(nn.Module):
 
     Examples
         >>> wavernn = _WaveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200)
-        >>> waveform, sample_rate = torchaudio.load(file) # waveform shape:
-        >>> (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
+        >>> waveform, sample_rate = torchaudio.load(file)
+        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
         >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
-        >>> output = wavernn(waveform.squeeze(1), specgram.squeeze(1)) # shape:
-        >>> (n_batch, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+        >>> output = wavernn(waveform.squeeze(1), specgram.squeeze(1))
+        >>> # output shape in 'waveform' mode: (n_batch, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+        >>> # output shape in 'mol' mode: (n_batch, (n_time - kernel_size + 1) * hop_length, 30)
     """
 
     def __init__(self,
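The docstring example only runs when the constraint stated above holds: the product of upsample_scales equals hop_length, and the waveform is (n_time - kernel_size + 1) * hop_length samples long for a spectrogram with n_time frames. A minimal self-contained sketch of that bookkeeping with random tensors, assuming defaults of kernel_size=5 and n_freq=128 (neither is visible in this diff):

import torch

# _WaveRNN as defined in this file; kernel_size=5 and n_freq=128 are assumed defaults.
hop_length = 200
kernel_size = 5
n_time = 32                                  # hypothetical number of spectrogram frames

# 5 * 5 * 8 == 200 == hop_length, as the docstring requires.
wavernn = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9,
                   sample_rate=24000, hop_length=hop_length)

specgram = torch.rand(1, 128, n_time)                              # (n_batch, n_freq, n_time)
waveform = torch.rand(1, (n_time - kernel_size + 1) * hop_length)  # matching waveform length

output = wavernn(waveform, specgram)
# 'waveform' mode: (1, (n_time - kernel_size + 1) * hop_length, 2 ** 9); 'mol' mode: last dim is 30
print(output.shape)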
@@ -287,8 +288,11 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
         batch_size = waveform.size(0)
         h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
         h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
-        mels, aux = self.upsample(specgram)
-        mels = mels.transpose(1, 2)
+        # output of upsample:
+        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+        specgram, aux = self.upsample(specgram)
+        specgram = specgram.transpose(1, 2)
         aux = aux.transpose(1, 2)
 
         aux_idx = [self.n_aux * i for i in range(5)]
@@ -297,21 +301,23 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
         a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
         a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
 
-        x = torch.cat([waveform.unsqueeze(-1), mels, a1], dim=2)
+        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
         x = self.fc(x)
         res = x
         x, _ = self.rnn1(x, h1)
 
         x = x + res
         res = x
-        x = torch.cat([x, a2], dim=2)
+        x = torch.cat([x, a2], dim=-1)
         x, _ = self.rnn2(x, h2)
 
         x = x + res
-        x = torch.cat([x, a3], dim=2)
-        x = self.relu1(self.fc1(x))
+        x = torch.cat([x, a3], dim=-1)
+        x = self.fc1(x)
+        x = self.relu1(x)
 
-        x = torch.cat([x, a4], dim=2)
-        x = self.relu2(self.fc2(x))
+        x = torch.cat([x, a4], dim=-1)
+        x = self.fc2(x)
+        x = self.relu2(x)
 
         return self.fc3(x)
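All of the dim=2 to dim=-1 changes above are equivalent in effect, since x, specgram and the aux chunks are 3-D (n_batch, n_time, n_features) tensors at that point; concatenation stays on the feature axis. The aux-splitting pattern the forward body relies on can be sketched standalone with made-up sizes:

import torch

# Made-up sizes for illustration; n_aux is the width of one conditioning chunk.
n_batch, n_time, n_aux = 2, 16, 32

# Stand-in for the (transposed) aux output of self.upsample: (n_batch, n_time, 4 * n_aux).
aux = torch.rand(n_batch, n_time, 4 * n_aux)

# Equivalent to slicing with aux_idx = [n_aux * i for i in range(5)] as in the diff.
a1, a2, a3, a4 = torch.split(aux, n_aux, dim=-1)

# a1 is concatenated before self.fc, a2 before self.rnn2,
# a3 before self.fc1, and a4 before self.fc2, always along the last (feature) axis.
print(a1.shape, a2.shape, a3.shape, a4.shape)   # each: (2, 16, 32)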