Commit 68cc72d

Add WaveRNN Model (#735)

Authored by jimchen90 (Ji Chen) and co-authors.

* upsamplenetwork
* update variable names
* update variable name
* add wavernn model
* update test
* update format
* update format
* update format
* fix conflicts and add transpose
* import update
* update transpose
* update format
* update docstring
* add n_channel in input
* add comment
* update docstring
* update docstring

Co-authored-by: Ji Chen <[email protected]>

1 parent ad7f43f commit 68cc72d

File tree

2 files changed (+198, −2):

* test/test_models.py
* torchaudio/models/_wavernn.py


test/test_models.py

Lines changed: 60 additions & 1 deletion
@@ -1,5 +1,5 @@
 import torch
-from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork
+from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN
 
 from . import common_utils
 
@@ -81,3 +81,62 @@ def test_waveform(self):
 
         assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1))
         assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1))
+
+
+class TestWaveRNN(common_utils.TorchaudioTestCase):
+
+    def test_waveform(self):
+        """Validate the output dimensions of a _WaveRNN model in waveform mode.
+        """
+
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_bits = 9
+        sample_rate = 24000
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+        mode = 'waveform'
+
+        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
+                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
+
+        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
+        mels = torch.rand(n_batch, 1, n_freq, n_time)
+        out = model(x, mels)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)
+
+    def test_mol(self):
+        """Validate the output dimensions of a _WaveRNN model in mol mode.
+        """
+
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_bits = 9
+        sample_rate = 24000
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+        mode = 'mol'
+
+        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
+                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
+
+        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
+        mels = torch.rand(n_batch, 1, n_freq, n_time)
+        out = model(x, mels)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
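
The trailing dimension of the output differs by mode: in 'waveform' mode it carries unnormalized scores over the 2 ** n_bits quantization levels, while in 'mol' mode it carries the 30 parameters of a discretized mixture-of-logistics output (conventionally 10 logistic mixtures, each with a weight, mean, and scale). As a minimal sketch, 'waveform'-mode logits could be reduced to quantized sample values roughly as follows; the greedy argmax and the [-1, 1] rescaling are illustrative assumptions, not part of this commit:

import torch

n_bits = 9
# stand-in for a 'waveform'-mode model output:
# (n_batch, n_channel, n_samples, 2 ** n_bits) unnormalized class scores
out = torch.rand(2, 1, 1000, 2 ** n_bits)

probs = torch.softmax(out, dim=-1)     # per-sample distribution over the 512 levels
indices = torch.argmax(probs, dim=-1)  # greedy pick: (n_batch, n_channel, n_samples)
audio = 2 * indices.float() / (2 ** n_bits - 1) - 1  # rescale class labels to [-1, 1]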

torchaudio/models/_wavernn.py

Lines changed: 138 additions & 1 deletion
@@ -1,9 +1,10 @@
 from typing import List
 
+import torch
 from torch import Tensor
 from torch import nn
 
-__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork"]
+__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]
 
 
 class _ResBlock(nn.Module):
@@ -192,3 +193,139 @@ def forward(self, specgram: Tensor) -> Tensor:
         upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]
 
         return upsampling_output, resnet_output
+
+
+class _WaveRNN(nn.Module):
+    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.
+
+    The original implementation was introduced in
+    `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
+    The input channels of waveform and spectrogram have to be 1. The product of
+    `upsample_scales` must equal `hop_length`.
+
+    Args:
+        upsample_scales: the list of upsample scales
+        n_bits: the number of bits of the output waveform
+        sample_rate: the sample rate of the audio (samples per second)
+        hop_length: the number of samples between the starts of consecutive frames
+        n_res_block: the number of ResBlock in stack (default=10)
+        n_rnn: the dimension of RNN layer (default=512)
+        n_fc: the dimension of fully connected layer (default=512)
+        kernel_size: the size of the kernel in the first Conv1d layer (default=5)
+        n_freq: the number of bins in a spectrogram (default=128)
+        n_hidden: the number of hidden dimensions (default=128)
+        n_output: the number of output dimensions (default=128)
+        mode: the output mode, one of ['waveform', 'mol'] (default='waveform')
+
+    Example:
+        >>> wavernn = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=9, sample_rate=24000, hop_length=200)
+        >>> waveform, sample_rate = torchaudio.load(file)
+        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
+        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
+        >>> output = wavernn(waveform, specgram)
+        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+    """
+
+    def __init__(self,
+                 upsample_scales: List[int],
+                 n_bits: int,
+                 sample_rate: int,
+                 hop_length: int,
+                 n_res_block: int = 10,
+                 n_rnn: int = 512,
+                 n_fc: int = 512,
+                 kernel_size: int = 5,
+                 n_freq: int = 128,
+                 n_hidden: int = 128,
+                 n_output: int = 128,
+                 mode: str = 'waveform') -> None:
+        super().__init__()
+
+        self.mode = mode
+        self.kernel_size = kernel_size
+
+        if self.mode == 'waveform':
+            self.n_classes = 2 ** n_bits
+        elif self.mode == 'mol':
+            self.n_classes = 30
+        else:
+            raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
+
+        self.n_rnn = n_rnn
+        self.n_aux = n_output // 4
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+
+        total_scale = 1
+        for upsample_scale in upsample_scales:
+            total_scale *= upsample_scale
+        if total_scale != self.hop_length:
+            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
+
+        self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
+
+        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
+        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)
+
+        self.relu1 = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
+        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
+        self.fc3 = nn.Linear(n_fc, self.n_classes)
+
+    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
+        r"""Pass the input through the _WaveRNN model.
+
+        Args:
+            waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
+            specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
+
+        Returns:
+            Tensor of shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits); the last dimension is 30 in 'mol' mode
+        """
+
+        assert waveform.size(1) == 1, 'Require the input channel of waveform to be 1'
+        assert specgram.size(1) == 1, 'Require the input channel of specgram to be 1'
+        # remove channel dimension until the end
+        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
+
+        batch_size = waveform.size(0)
+        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+        # output of upsample:
+        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+        specgram, aux = self.upsample(specgram)
+        specgram = specgram.transpose(1, 2)
+        aux = aux.transpose(1, 2)
+
+        aux_idx = [self.n_aux * i for i in range(5)]
+        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
+        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
+        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
+        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
+
+        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
+        x = self.fc(x)
+        res = x
+        x, _ = self.rnn1(x, h1)
+
+        x = x + res
+        res = x
+        x = torch.cat([x, a2], dim=-1)
+        x, _ = self.rnn2(x, h2)
+
+        x = x + res
+        x = torch.cat([x, a3], dim=-1)
+        x = self.fc1(x)
+        x = self.relu1(x)
+
+        x = torch.cat([x, a4], dim=-1)
+        x = self.fc2(x)
+        x = self.relu2(x)
+        x = self.fc3(x)
+
+        # bring back channel dimension
+        return x.unsqueeze(1)
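
Putting the new class to work end to end, a minimal smoke-test sketch (the tensor sizes simply follow the shape contract in the docstring; _WaveRNN is a private API here, its signature may change, and the values below are illustrative assumptions):

import torch
from torchaudio.models import _WaveRNN

hop_length, kernel_size, n_time, n_bits = 200, 5, 50, 9

# upsample_scales must multiply out to hop_length (5 * 5 * 8 == 200)
model = _WaveRNN(upsample_scales=[5, 5, 8], n_bits=n_bits,
                 sample_rate=24000, hop_length=hop_length)

# both inputs carry a single channel, as required by the asserts in forward()
waveform = torch.rand(1, 1, hop_length * (n_time - kernel_size + 1))
specgram = torch.rand(1, 1, 128, n_time)  # n_freq defaults to 128

out = model(waveform, specgram)
assert out.size() == (1, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)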
