diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index cd2e89a10c..6a3108d6f3 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -4,7 +4,16 @@ from torch import Tensor from torch import nn -__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"] +from .utils import load_state_dict_from_url + + +__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN", "_wavernn"] + + +model_urls = { + # FIXME Replace URL by final one once determined + '_wavernn': 'https://download.pytorch.org/models/_wavernn.pth', +} class _ResBlock(nn.Module): @@ -329,3 +338,19 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: # bring back channel dimension return x.unsqueeze(1) + + +def _wavernn(pretrained=False, progress=True, **kwargs): + r"""WaveRNN model based on the implementation from + `fatchord `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on LJSpeech + progress (bool): If True, displays a progress bar of the download to stderr + """ + model = _WaveRNN(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['_wavernn'], + progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/torchaudio/models/utils.py b/torchaudio/models/utils.py new file mode 100644 index 0000000000..638ef07cd8 --- /dev/null +++ b/torchaudio/models/utils.py @@ -0,0 +1,4 @@ +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url