From 3dac50a2c5a9bf59fd5f670358a9c45470398683 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Sun, 12 Jul 2020 12:39:33 -0700 Subject: [PATCH 1/5] Add pretrained model --- torchaudio/models/_wavernn.py | 25 ++++++++++++++++++++++++- torchaudio/models/utils.py | 4 ++++ 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 torchaudio/models/utils.py diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index cd2e89a10c..5fcdc3050b 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -3,8 +3,15 @@ import torch from torch import Tensor from torch import nn +from .utils import load_state_dict_from_url -__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"] + +__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN", "_wavernn"] + + +model_urls = { + '_wavernn': 'https://download.pytorch.org/models/_wavernn.pth', +} class _ResBlock(nn.Module): @@ -329,3 +336,19 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: # bring back channel dimension return x.unsqueeze(1) + + +def _wavernn(pretrained=False, progress=True, **kwargs): + r"""WaveRNN model based on the implementation from + `fatchord `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on LJSpeech + progress (bool): If True, displays a progress bar of the download to stderr + """ + model = _WaveRNN(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['_wavernn'], + progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/torchaudio/models/utils.py b/torchaudio/models/utils.py new file mode 100644 index 0000000000..638ef07cd8 --- /dev/null +++ b/torchaudio/models/utils.py @@ -0,0 +1,4 @@ +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url From 37d4322be53746370954db8ad6b332a8d9c64477 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 11:08:16 -0700 Subject: [PATCH 2/5] Add model url --- torchaudio/models/_wavernn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 5fcdc3050b..9c5b58557d 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -3,6 +3,7 @@ import torch from torch import Tensor from torch import nn + from .utils import load_state_dict_from_url @@ -10,7 +11,7 @@ model_urls = { - '_wavernn': 'https://download.pytorch.org/models/_wavernn.pth', + '_wavernn': 'https://ossci-assets.s3.amazonaws.com/torchaudio/wavernn_8bits_waveform_ljspeech.pth', } From 04c37db3818ca43972c848c0e87eb09f14382b21 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 11:16:31 -0700 Subject: [PATCH 3/5] update utils import --- torchaudio/models/_wavernn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 9c5b58557d..339662295e 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -4,7 +4,7 @@ from torch import Tensor from torch import nn -from .utils import load_state_dict_from_url +from utils import load_state_dict_from_url __all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN", "_wavernn"] From 7e46e8c86d16fb5d5e4d0e7d141529760fff5e40 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 13:27:47 -0700 Subject: [PATCH 4/5] fix utils import --- torchaudio/models/_wavernn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 339662295e..9c5b58557d 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -4,7 +4,7 @@ from torch import Tensor from torch import nn -from utils import load_state_dict_from_url +from .utils import load_state_dict_from_url __all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN", "_wavernn"] From 84a88310be467aea57f5403474a27ea71f43a2b2 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 15 Jul 2020 14:10:27 -0700 Subject: [PATCH 5/5] update url --- torchaudio/models/_wavernn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 9c5b58557d..6a3108d6f3 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -11,7 +11,8 @@ model_urls = { - '_wavernn': 'https://ossci-assets.s3.amazonaws.com/torchaudio/wavernn_8bits_waveform_ljspeech.pth', + # FIXME Replace URL by final one once determined + '_wavernn': 'https://download.pytorch.org/models/_wavernn.pth', }