Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion torchaudio/models/_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from torch import Tensor
from torch import nn

__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]
from .utils import load_state_dict_from_url


__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN", "_wavernn"]


model_urls = {
# FIXME Replace URL by final one once determined
'_wavernn': 'https://download.pytorch.org/models/_wavernn.pth',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It's conventional to add a FIXME comment here

    # FIXME Replace URL by final one once determined
    '_wavernn': 'https://download.pytorch.org/models/_wavernn.pth',

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: once you know the actual URL, you can add it here without the underscore in the URL :)

}


class _ResBlock(nn.Module):
Expand Down Expand Up @@ -329,3 +338,19 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:

# bring back channel dimension
return x.unsqueeze(1)


def _wavernn(pretrained=False, progress=True, **kwargs):
r"""WaveRNN model based on the implementation from
`fatchord <https://github.com/fatchord/WaveRNN>`_.

Args:
pretrained (bool): If True, returns a model pre-trained on LJSpeech
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = _WaveRNN(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['_wavernn'],
progress=progress)
model.load_state_dict(state_dict)
return model
4 changes: 4 additions & 0 deletions torchaudio/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url