|
17 | 17 | ] |
18 | 18 |
|
19 | 19 |
|
20 | | -model_config_and_urls: Dict[str, Tuple[str, Dict[str, Any]]] = { |
| 20 | +_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = { |
21 | 21 | 'wavernn_10k_epochs_8bits_ljspeech': ( |
22 | 22 | 'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth', |
23 | 23 | { |
@@ -363,11 +363,11 @@ def get_pretrained_wavernn(checkpoint_name: str, progress: bool = True) -> WaveR |
363 | 363 | checkpoint_name (str): The name of the checkpoint to load. |
364 | 364 | progress (bool): If True, displays a progress bar of the download to stderr. |
365 | 365 | """ |
366 | | - if checkpoint_name in model_config_and_urls: |
367 | | - url, configs = model_config_and_urls[checkpoint_name] |
368 | | - model = WaveRNN(**configs) |
369 | | - state_dict = load_state_dict_from_url(url, progress=progress) |
370 | | - model.load_state_dict(state_dict) |
371 | | - return model |
372 | | - else: |
373 | | - raise ValueError("The model_name `{}` is not supported.".format(checkpoint_name)) |
| 366 | + if checkpoint_name not in _MODEL_CONFIG_AND_URLS: |
| 367 | + raise ValueError("The checkpoint_name `{}` is not supported.".format(checkpoint_name)) |
| 368 | + |
| 369 | + url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name] |
| 370 | + model = WaveRNN(**configs) |
| 371 | + state_dict = load_state_dict_from_url(url, progress=progress) |
| 372 | + model.load_state_dict(state_dict) |
| 373 | + return model |
0 commit comments