Skip to content

Commit 7f53c53

Browse files
author
Ji Chen
committed
update format
1 parent f206a8c commit 7f53c53

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

torchaudio/models/_wavernn.py

Lines changed: 20 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -198,8 +198,8 @@ def forward(self, specgram: Tensor) -> Tensor:
198198
class _WaveRNN(nn.Module):
199199
r"""WaveRNN model based on "Efficient Neural Audio Synthesis".
200200
201-
The paper link is https://arxiv.org/pdf/1802.08435.pdf. The input channels of waveform
202-
and spectrogram have to be 1. The product of upsample_scales must equal hop_length.
201+
The paper link is `<https://arxiv.org/pdf/1802.08435.pdf>`_. The input channels of waveform
202+
and spectrogram have to be 1. The product of `upsample_scales` must equal `hop_length`.
203203
204204
Args:
205205
upsample_scales: the list of upsample scales
@@ -217,11 +217,12 @@ class _WaveRNN(nn.Module):
217217
218218
Examples
219219
>>> wavernn = _WaveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200)
220-
>>> waveform, sample_rate = torchaudio.load(file) # waveform shape:
221-
>>> (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
220+
>>> waveform, sample_rate = torchaudio.load(file)
221+
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
222222
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
223-
>>> output = wavernn(waveform.squeeze(1), specgram.squeeze(1)) # shape:
224-
>>> (n_batch, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
223+
>>> output = wavernn(waveform.squeeze(1), specgram.squeeze(1))
224+
>>> # output shape in 'waveform' mode: (n_batch, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
225+
>>> # output shape in 'mol' mode: (n_batch, (n_time - kernel_size + 1) * hop_length, 30)
225226
"""
226227

227228
def __init__(self,
@@ -287,8 +288,11 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
287288
batch_size = waveform.size(0)
288289
h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
289290
h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
290-
mels, aux = self.upsample(specgram)
291-
mels = mels.transpose(1, 2)
291+
# output of upsample:
292+
# specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
293+
# aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
294+
specgram, aux = self.upsample(specgram)
295+
specgram = specgram.transpose(1, 2)
292296
aux = aux.transpose(1, 2)
293297

294298
aux_idx = [self.n_aux * i for i in range(5)]
@@ -297,21 +301,23 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
297301
a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
298302
a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
299303

300-
x = torch.cat([waveform.unsqueeze(-1), mels, a1], dim=2)
304+
x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
301305
x = self.fc(x)
302306
res = x
303307
x, _ = self.rnn1(x, h1)
304308

305309
x = x + res
306310
res = x
307-
x = torch.cat([x, a2], dim=2)
311+
x = torch.cat([x, a2], dim=-1)
308312
x, _ = self.rnn2(x, h2)
309313

310314
x = x + res
311-
x = torch.cat([x, a3], dim=2)
312-
x = self.relu1(self.fc1(x))
315+
x = torch.cat([x, a3], dim=-1)
316+
x = self.fc1(x)
317+
x = self.relu1(x)
313318

314-
x = torch.cat([x, a4], dim=2)
315-
x = self.relu2(self.fc2(x))
319+
x = torch.cat([x, a4], dim=-1)
320+
x = self.fc2(x)
321+
x = self.relu2(x)
316322

317323
return self.fc3(x)

0 commit comments

Comments
 (0)