9 changes: 8 additions & 1 deletion docs/source/models.rst
@@ -88,8 +88,15 @@ WaveRNN

  .. automethod:: forward

Factory Functions
-----------------

wavernn
-------

.. autofunction:: wavernn

References
~~~~~~~~~~

.. footbibliography::

3 changes: 2 additions & 1 deletion torchaudio/models/__init__.py
@@ -1,5 +1,5 @@
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .wavernn import WaveRNN, wavernn
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .wav2vec2 import (
@@ -13,6 +13,7 @@
__all__ = [
    'Wav2Letter',
    'WaveRNN',
    'wavernn',
    'ConvTasNet',
    'DeepSpeech',
    'Wav2Vec2Model',
49 changes: 48 additions & 1 deletion torchaudio/models/wavernn.py
@@ -1,18 +1,40 @@
from typing import List, Tuple
from typing import List, Tuple, Dict, Any

import torch
from torch import Tensor
from torch import nn
from torch.hub import load_state_dict_from_url


__all__ = [
"ResBlock",
"MelResNet",
"Stretch2d",
"UpsampleNetwork",
"WaveRNN",
"wavernn",
]


_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    'wavernn_10k_epochs_8bits_ljspeech': (
        'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
        {
            'upsample_scales': [5, 5, 11],
            'n_classes': 2 ** 8,  # n_bits = 8
            'hop_length': 275,
            'n_res_block': 10,
            'n_rnn': 512,
            'n_fc': 512,
            'kernel_size': 5,
            'n_freq': 80,
            'n_hidden': 128,
            'n_output': 128
        }
    )
}


class ResBlock(nn.Module):
r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].

@@ -324,3 +346,28 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:

        # bring back channel dimension
        return x.unsqueeze(1)


def wavernn(checkpoint_name: str) -> WaveRNN:
r"""Get pretrained WaveRNN model.

Args:
checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:

- ``"wavernn_10k_epochs_8bits_ljspeech"``:

              WaveRNN model trained for 10k epochs on 8-bit waveforms from the LJSpeech dataset.
              The model is trained using the default parameters and code of
              `examples/pipeline_wavernn/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        raise ValueError(
            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
f"Valid choices are; {list(_MODEL_CONFIG_AND_URLS.keys())}")

    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = WaveRNN(**configs)
    state_dict = load_state_dict_from_url(url, progress=False)
    model.load_state_dict(state_dict)
    return model
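
For reference, a minimal usage sketch of the new factory function (not part of this diff). The checkpoint name and configuration values (hop_length=275, kernel_size=5, n_freq=80, n_classes=256) come from _MODEL_CONFIG_AND_URLS above; the dummy input shapes are an assumption based on the (batch, 1, time) waveform and (batch, 1, n_freq, n_time) spectrogram convention documented for WaveRNN.forward.

```python
import torch
from torchaudio.models import wavernn

# Download the pretrained weights (cached by torch.hub after the first call)
# and build a WaveRNN instance with the matching configuration.
model = wavernn("wavernn_10k_epochs_8bits_ljspeech")
model.eval()

# Dummy inputs sized for this checkpoint: a spectrogram with n_time frames
# pairs with (n_time - kernel_size + 1) * hop_length waveform samples.
n_time = 10
specgram = torch.randn(1, 1, 80, n_time)               # (batch, channel, n_freq, n_time)
waveform = torch.randn(1, 1, (n_time - 5 + 1) * 275)   # (batch, channel, time)

with torch.no_grad():
    output = model(waveform, specgram)                  # (batch, channel, time, n_classes)
```

Passing an unknown name raises the ValueError above, whose message lists the valid checkpoint names; since load_state_dict_from_url is called with progress=False, the download is silent and the weights are cached in the torch hub directory.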