From 5e067d85257adf3f280f0963d5f7b59e7e084f6e Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Thu, 24 Jun 2021 21:15:54 +0000
Subject: [PATCH 1/5] Add pretrained WaveRNN

---
 torchaudio/models/__init__.py |  3 ++-
 torchaudio/models/_utils.py   |  8 ++++++
 torchaudio/models/wavernn.py  | 46 ++++++++++++++++++++++++++++++++++-
 3 files changed, 55 insertions(+), 2 deletions(-)
 create mode 100644 torchaudio/models/_utils.py

diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py
index 843f6d15f5..ae4a2993cd 100644
--- a/torchaudio/models/__init__.py
+++ b/torchaudio/models/__init__.py
@@ -1,5 +1,5 @@
 from .wav2letter import Wav2Letter
-from .wavernn import WaveRNN
+from .wavernn import WaveRNN, wavernn_10k_epochs_8bits_ljspeech
 from .conv_tasnet import ConvTasNet
 from .deepspeech import DeepSpeech
 from .wav2vec2 import (
@@ -13,6 +13,7 @@
 __all__ = [
     'Wav2Letter',
     'WaveRNN',
+    'wavernn_10k_epochs_8bits_ljspeech',
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
diff --git a/torchaudio/models/_utils.py b/torchaudio/models/_utils.py
new file mode 100644
index 0000000000..3f4800dc75
--- /dev/null
+++ b/torchaudio/models/_utils.py
@@ -0,0 +1,8 @@
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = [
+    'load_state_dict_from_url',
+]
diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 89c1e9d430..065657b3c5 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -1,18 +1,28 @@
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import torch
 from torch import Tensor
 from torch import nn
 
+from ._utils import load_state_dict_from_url
+
+
 __all__ = [
     "ResBlock",
     "MelResNet",
     "Stretch2d",
     "UpsampleNetwork",
     "WaveRNN",
+    "wavernn_10k_epochs_8bits_ljspeech",
 ]
 
 
+model_urls = {
+    'wavernn_10k_epochs_8bits_ljspeech': 'https://download.pytorch.org/models/'
+                                         'audio/wavernn_10k_epochs_8bits_ljspeech.pth',
+}
+
+
 class ResBlock(nn.Module):
     r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
 
@@ -324,3 +334,37 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
 
         # bring back channel dimension
         return x.unsqueeze(1)
+
+
+def _wavernn(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> WaveRNN:
+    model = WaveRNN(**kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def wavernn_10k_epochs_8bits_ljspeech(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> WaveRNN:
+    r"""WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
+    The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on LJSpeech
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    n_bits = 8
+    configs = {
+        'upsample_scales': [5, 5, 11],
+        'n_classes': 2 ** n_bits,
+        'hop_length': 275,
+        'n_res_block': 10,
+        'n_rnn': 512,
+        'n_fc': 512,
+        'kernel_size': 5,
+        'n_freq': 80,
+        'n_hidden': 128,
+        'n_output': 128
+    }
+    configs.update(kwargs)
+    return _wavernn("wavernn_10k_epochs_8bits_ljspeech", pretrained=pretrained, progress=progress, **configs)

From 330e3295b8473714e1ca68cca7f828022a4b871f Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Thu, 8 Jul 2021 18:33:14 +0000
Subject: [PATCH 2/5] Refactor the pretrained WaveRNN interface

---
 torchaudio/models/__init__.py |  4 +--
 torchaudio/models/wavernn.py  | 67 ++++++++++++++++++-----------------
 2 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py
index ae4a2993cd..1a1a85d874 100644
--- a/torchaudio/models/__init__.py
+++ b/torchaudio/models/__init__.py
@@ -1,5 +1,5 @@
 from .wav2letter import Wav2Letter
-from .wavernn import WaveRNN, wavernn_10k_epochs_8bits_ljspeech
+from .wavernn import WaveRNN, get_pretrained_wavernn
 from .conv_tasnet import ConvTasNet
 from .deepspeech import DeepSpeech
 from .wav2vec2 import (
@@ -13,7 +13,7 @@
 __all__ = [
     'Wav2Letter',
     'WaveRNN',
-    'wavernn_10k_epochs_8bits_ljspeech',
+    'get_pretrained_wavernn',
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 065657b3c5..78c93fa0ce 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any
+from typing import List, Tuple, Dict, Any
 
 import torch
 from torch import Tensor
@@ -13,13 +13,26 @@
     "Stretch2d",
     "UpsampleNetwork",
     "WaveRNN",
-    "wavernn_10k_epochs_8bits_ljspeech",
+    "get_pretrained_wavernn",
 ]
 
 
-model_urls = {
-    'wavernn_10k_epochs_8bits_ljspeech': 'https://download.pytorch.org/models/'
-                                         'audio/wavernn_10k_epochs_8bits_ljspeech.pth',
+model_config_and_urls: Dict[str, Tuple[str, Dict[str, Any]]] = {
+    'wavernn_10k_epochs_8bits_ljspeech': (
+        'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
+        {
+            'upsample_scales': [5, 5, 11],
+            'n_classes': 2 ** 8,  # n_bits = 8
+            'hop_length': 275,
+            'n_res_block': 10,
+            'n_rnn': 512,
+            'n_fc': 512,
+            'kernel_size': 5,
+            'n_freq': 80,
+            'n_hidden': 128,
+            'n_output': 128
+        }
+    )
 }
 
 
@@ -336,35 +349,25 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
     return x.unsqueeze(1)
 
 
-def _wavernn(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> WaveRNN:
-    model = WaveRNN(**kwargs)
-    if pretrained:
-        state_dict = load_state_dict_from_url(model_urls[arch],
-                                              progress=progress)
-        model.load_state_dict(state_dict)
-    return model
+def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveRNN:
+    r"""Get pretrained WaveRNN model.
+
+    Here are the available checkpoints:
+
+    - wavernn_10k_epochs_8bits_ljspeech
 
-def wavernn_10k_epochs_8bits_ljspeech(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> WaveRNN:
-    r"""WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
-    The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py.
+      WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
+      The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py.
 
     Args:
-        pretrained (bool): If True, returns a model pre-trained on LJSpeech
-        progress (bool): If True, displays a progress bar of the download to stderr
+        checkpoint_name (str): The name of the checkpoint to load.
+        progress (bool): If True, displays a progress bar of the download to stderr.
     """
-    n_bits = 8
-    configs = {
-        'upsample_scales': [5, 5, 11],
-        'n_classes': 2 ** n_bits,
-        'hop_length': 275,
-        'n_res_block': 10,
-        'n_rnn': 512,
-        'n_fc': 512,
-        'kernel_size': 5,
-        'n_freq': 80,
-        'n_hidden': 128,
-        'n_output': 128
-    }
-    configs.update(kwargs)
-    return _wavernn("wavernn_10k_epochs_8bits_ljspeech", pretrained=pretrained, progress=progress, **configs)
+    if checkpoint_name in model_config_and_urls:
+        url, configs = model_config_and_urls[checkpoint_name]
+        model = WaveRNN(**configs)
+        state_dict = load_state_dict_from_url(url, progress=progress)
+        model.load_state_dict(state_dict)
+        return model
+    else:
+        raise ValueError("The model_name `{}` is not supported.".format(checkpoint_name))

From 8f0466d5ed96617b5d3fe571650ae9f427449922 Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Mon, 12 Jul 2021 18:25:08 +0000
Subject: [PATCH 3/5] Fix a few coding style issues

---
 torchaudio/models/_utils.py  |  8 --------
 torchaudio/models/wavernn.py | 24 +++++++++++++-----------
 2 files changed, 13 insertions(+), 19 deletions(-)
 delete mode 100644 torchaudio/models/_utils.py

diff --git a/torchaudio/models/_utils.py b/torchaudio/models/_utils.py
deleted file mode 100644
index 3f4800dc75..0000000000
--- a/torchaudio/models/_utils.py
+++ /dev/null
@@ -1,8 +0,0 @@
-try:
-    from torch.hub import load_state_dict_from_url
-except ImportError:
-    from torch.utils.model_zoo import load_url as load_state_dict_from_url
-
-__all__ = [
-    'load_state_dict_from_url',
-]
diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 78c93fa0ce..f5142382b6 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -3,8 +3,10 @@
 import torch
 from torch import Tensor
 from torch import nn
-
-from ._utils import load_state_dict_from_url
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
 
 
 __all__ = [
@@ -17,7 +19,7 @@
 ]
 
 
-model_config_and_urls: Dict[str, Tuple[str, Dict[str, Any]]] = {
+_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
     'wavernn_10k_epochs_8bits_ljspeech': (
         'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
         {
@@ -363,11 +365,11 @@ def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveR
         checkpoint_name (str): The name of the checkpoint to load.
         progress (bool): If True, displays a progress bar of the download to stderr.
""" - if checkpoint_name in model_config_and_urls: - url, configs = model_config_and_urls[checkpoint_name] - model = WaveRNN(**configs) - state_dict = load_state_dict_from_url(url, progress=progress) - model.load_state_dict(state_dict) - return model - else: - raise ValueError("The model_name `{}` is not supported.".format(checkpoint_name)) + if checkpoint_name not in _MODEL_CONFIG_AND_URLS: + raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name)) + + url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name] + model = WaveRNN(**configs) + state_dict = load_state_dict_from_url(url, progress=progress) + model.load_state_dict(state_dict) + return model From e53390b2ced0913c4b5fd5cf21ea1b8e5ba79aaf Mon Sep 17 00:00:00 2001 From: Yao-Yuan Yang Date: Mon, 12 Jul 2021 21:51:58 +0000 Subject: [PATCH 4/5] Add pretrained WaveRNN to docs --- docs/source/models.rst | 9 ++++++++- torchaudio/models/__init__.py | 4 ++-- torchaudio/models/wavernn.py | 28 ++++++++++++---------------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 39e162baa0..c0d70e2f7c 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -88,8 +88,15 @@ WaveRNN .. automethod:: forward +Factory Functions +----------------- + +wavernn +------- + +.. autofunction:: wavernn + References ~~~~~~~~~~ .. footbibliography:: - diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py index 1a1a85d874..af622b0a73 100644 --- a/torchaudio/models/__init__.py +++ b/torchaudio/models/__init__.py @@ -1,5 +1,5 @@ from .wav2letter import Wav2Letter -from .wavernn import WaveRNN, get_pretrained_wavernn +from .wavernn import WaveRNN, wavernn from .conv_tasnet import ConvTasNet from .deepspeech import DeepSpeech from .wav2vec2 import ( @@ -13,7 +13,7 @@ __all__ = [ 'Wav2Letter', 'WaveRNN', - 'get_pretrained_wavernn', + 'wavernn', 'ConvTasNet', 'DeepSpeech', 'Wav2Vec2Model', diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py index f5142382b6..42535173ea 100644 --- a/torchaudio/models/wavernn.py +++ b/torchaudio/models/wavernn.py @@ -3,10 +3,7 @@ import torch from torch import Tensor from torch import nn -try: - from torch.hub import load_state_dict_from_url -except ImportError: - from torch.utils.model_zoo import load_url as load_state_dict_from_url +from torch.hub import load_state_dict_from_url __all__ = [ @@ -15,7 +12,7 @@ "Stretch2d", "UpsampleNetwork", "WaveRNN", - "get_pretrained_wavernn", + "wavernn", ] @@ -351,25 +348,24 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: return x.unsqueeze(1) -def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveRNN: +def wavernn(checkpoint_name: str) -> WaveRNN: r"""Get pretrained WaveRNN model. - Here are the available checkpoints: - - - wavernn_10k_epochs_8bits_ljspeech + Args: + checkpoint_name (str): The name of the checkpoint to load. Available checkpoints: - WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset. - The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py. + - ``"wavernn_10k_epochs_8bits_ljspeech"``: - Args: - checkpoint_name (str): The name of the checkpoint to load. - progress (bool): If True, displays a progress bar of the download to stderr. + WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset. 
+                The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py.
     """
     if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
-        raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name))
+        raise ValueError(
+            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
+            f"Valid choices are: {list(_MODEL_CONFIG_AND_URLS.keys())}")
 
     url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
     model = WaveRNN(**configs)
-    state_dict = load_state_dict_from_url(url, progress=progress)
+    state_dict = load_state_dict_from_url(url, progress=False)
     model.load_state_dict(state_dict)
     return model

From c4c7fa4c6f8c26b9366d995f1697ea92d2bdce8c Mon Sep 17 00:00:00 2001
From: Yao-Yuan Yang
Date: Tue, 20 Jul 2021 21:01:52 +0000
Subject: [PATCH 5/5] Update docstring

---
 torchaudio/models/wavernn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py
index 42535173ea..5eb10d12cf 100644
--- a/torchaudio/models/wavernn.py
+++ b/torchaudio/models/wavernn.py
@@ -357,7 +357,9 @@ def wavernn(checkpoint_name: str) -> WaveRNN:
 
             - ``"wavernn_10k_epochs_8bits_ljspeech"``:
 
                 WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
-                The model is trained using the default parameters and code of the examples/pipeline_wavernn/main.py.
+                The model is trained using the default parameters and code of the
+                `examples/pipeline_wavernn/main.py
+                <https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.
     """
     if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
         raise ValueError(
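
A minimal usage sketch of the `wavernn` factory as it stands after patch 5 (not part of the patches). The function name, checkpoint name, and the `forward(waveform, specgram)` call come from the diffs above; the dummy tensor shapes are an assumption based on the `WaveRNN.forward` docstring for this checkpoint's config (n_freq=80, hop_length=275, kernel_size=5), and running it requires network access to download the weights.

    import torch
    from torchaudio.models import wavernn

    # Build the model and load the pretrained ljspeech checkpoint (downloads the weights on first use).
    model = wavernn("wavernn_10k_epochs_8bits_ljspeech")
    model.eval()

    # Dummy inputs; shapes assume the forward() contract documented in wavernn.py:
    # specgram is (n_batch, 1, n_freq, n_time) and waveform is
    # (n_batch, 1, (n_time - kernel_size + 1) * hop_length).
    n_batch, n_time = 1, 32
    specgram = torch.rand(n_batch, 1, 80, n_time)
    waveform = torch.rand(n_batch, 1, (n_time - 5 + 1) * 275)

    with torch.no_grad():
        out = model(waveform, specgram)  # (n_batch, 1, (n_time - 5 + 1) * 275, 256)

The last dimension of the output should hold unnormalized scores over the 256 output classes, matching the `'n_classes': 2 ** 8` entry in `_MODEL_CONFIG_AND_URLS`.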