@@ -1,18 +1,28 @@
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import torch
 from torch import Tensor
 from torch import nn
 
+from ._utils import load_state_dict_from_url
+
+
 __all__ = [
     "ResBlock",
     "MelResNet",
     "Stretch2d",
     "UpsampleNetwork",
     "WaveRNN",
+    "wavernn_10k_epochs_8bits_ljspeech",
 ]
 
 
+model_urls = {
+    'wavernn_10k_epochs_8bits_ljspeech': 'https://download.pytorch.org/models/'
+                                         'audio/wavernn_10k_epochs_8bits_ljspeech.pth',
+}
+
+
 class ResBlock(nn.Module):
     r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
 
@@ -324,3 +334,37 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
 
         # bring back channel dimension
         return x.unsqueeze(1)
+
+
+def _wavernn(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> WaveRNN:
+    model = WaveRNN(**kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def wavernn_10k_epochs_8bits_ljspeech(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> WaveRNN:
+    r"""WaveRNN model trained for 10k epochs on 8-bit depth waveforms from the LJSpeech dataset.
+    The model is trained using the default parameters and code of examples/pipeline_wavernn/main.py.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on LJSpeech
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    n_bits = 8
+    configs = {
+        'upsample_scales': [5, 5, 11],
+        'n_classes': 2 ** n_bits,
+        'hop_length': 275,
+        'n_res_block': 10,
+        'n_rnn': 512,
+        'n_fc': 512,
+        'kernel_size': 5,
+        'n_freq': 80,
+        'n_hidden': 128,
+        'n_output': 128
+    }
+    configs.update(kwargs)
+    return _wavernn("wavernn_10k_epochs_8bits_ljspeech", pretrained=pretrained, progress=progress, **configs)
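Below is a minimal usage sketch of the new factory function. It assumes the diff above is applied and that `wavernn_10k_epochs_8bits_ljspeech` is re-exported from `torchaudio.models`; the tensor shapes follow the `WaveRNN.forward` contract documented elsewhere in this file, and the dummy size `n_time = 16` is an arbitrary choice.

```python
import torch
from torchaudio.models import wavernn_10k_epochs_8bits_ljspeech  # assumes the factory is re-exported here

# Build the model and download the pretrained LJSpeech checkpoint
# (pass pretrained=False to get randomly initialized weights instead).
model = wavernn_10k_epochs_8bits_ljspeech(pretrained=True)
model.eval()

# Dummy inputs. Shapes assume the WaveRNN.forward contract:
#   specgram: (n_batch, 1, n_freq, n_time)
#   waveform: (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
# with n_freq=80, kernel_size=5 and hop_length=275 from the config in this diff.
n_time = 16
specgram = torch.randn(1, 1, 80, n_time)
waveform = torch.randn(1, 1, (n_time - 5 + 1) * 275)

with torch.no_grad():
    # Expected output shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes=256)
    output = model(waveform, specgram)
```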