Skip to content

Commit b6eca86

Browse files
committed
Fix a few coding style
1 parent 330e329 commit b6eca86

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

torchaudio/models/wavernn.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
]
1818

1919

20-
model_config_and_urls: Dict[str, Tuple[str, Dict[str, Any]]] = {
20+
_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
2121
'wavernn_10k_epochs_8bits_ljspeech': (
2222
'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
2323
{
@@ -363,11 +363,11 @@ def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveR
363363
checkpoint_name (str): The name of the checkpoint to load.
364364
progress (bool): If True, displays a progress bar of the download to stderr.
365365
"""
366-
if checkpoint_name in model_config_and_urls:
367-
url, configs = model_config_and_urls[checkpoint_name]
368-
model = WaveRNN(**configs)
369-
state_dict = load_state_dict_from_url(url, progress=progress)
370-
model.load_state_dict(state_dict)
371-
return model
372-
else:
373-
raise ValueError("The model_name `{}` is not supported.".format(checkpoint_name))
366+
if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
367+
raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name))
368+
369+
url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
370+
model = WaveRNN(**configs)
371+
state_dict = load_state_dict_from_url(url, progress=progress)
372+
model.load_state_dict(state_dict)
373+
return model

0 commit comments

Comments
 (0)