Skip to content

Commit 0fabf59

Browse files
committed
Add pretrained wavernn
1 parent 284bd10 commit 0fabf59

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

torchaudio/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .wav2letter import Wav2Letter
2-
from .wavernn import WaveRNN
2+
from .wavernn import WaveRNN, wavernn_10k_epochs_8bits_ljspeech
33
from .conv_tasnet import ConvTasNet
44
from .deepspeech import DeepSpeech
55
from .wav2vec2 import (
@@ -13,6 +13,7 @@
1313
__all__ = [
1414
'Wav2Letter',
1515
'WaveRNN',
16+
'wavernn_10k_epochs_8bits_ljspeech',
1617
'ConvTasNet',
1718
'DeepSpeech',
1819
'Wav2Vec2Model',

torchaudio/models/_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
try:
2+
from torch.hub import load_state_dict_from_url
3+
except ImportError:
4+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
5+
6+
__all__ = [
7+
'load_state_dict_from_url',
8+
]

torchaudio/models/wavernn.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
1-
from typing import List, Tuple
1+
from typing import List, Tuple, Any
22

33
import torch
44
from torch import Tensor
55
from torch import nn
66

7+
from ._utils import load_state_dict_from_url
8+
9+
710
__all__ = [
811
"ResBlock",
912
"MelResNet",
1013
"Stretch2d",
1114
"UpsampleNetwork",
1215
"WaveRNN",
16+
"wavernn_10k_epochs_8bits_ljspeech",
1317
]
1418

1519

20+
model_urls = {
21+
'wavernn_10k_epochs_8bits_ljspeech': 'https://download.pytorch.org/models/'
22+
'audio/wavernn_10k_epochs_8bits_ljspeech.pth',
23+
}
24+
25+
1626
class ResBlock(nn.Module):
1727
r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
1828
@@ -324,3 +334,37 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
324334

325335
# bring back channel dimension
326336
return x.unsqueeze(1)
337+
338+
339+
def _wavernn(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> WaveRNN:
340+
model = WaveRNN(**kwargs)
341+
if pretrained:
342+
state_dict = load_state_dict_from_url(model_urls['wavernn'],
343+
progress=progress)
344+
model.load_state_dict(state_dict)
345+
return model
346+
347+
348+
def wavernn_10k_epochs_8bits_ljspeech(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> WaveRNN:
349+
r"""WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
350+
The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py.
351+
352+
Args:
353+
pretrained (bool): If True, returns a model pre-trained on LJSpeech
354+
progress (bool): If True, displays a progress bar of the download to stderr
355+
"""
356+
n_bits = 8
357+
configs = {
358+
'upsample_scales': [5, 5, 11],
359+
'n_classes': 2 ** n_bits,
360+
'hop_length': 275,
361+
'n_res_block': 10,
362+
'n_rnn': 512,
363+
'n_fc': 512,
364+
'kernel_size': 5,
365+
'n_freq': 80,
366+
'n_hidden': 128,
367+
'n_output': 128
368+
}
369+
configs.update(kwargs)
370+
return _wavernn("wavernn_10k_epochs_8bits_ljspeech", pretrained=pretrained, progress=progress, **configs)

0 commit comments

Comments
 (0)