16
16
17
17
import logging
18
18
import warnings
19
+ import os
20
+
19
21
from collections import OrderedDict
20
22
21
23
from tensorflow_tts .configs import (
40
42
SavableTFFastSpeech2 ,
41
43
SavableTFTacotron2
42
44
)
45
+ from tensorflow_tts .utils import CACHE_DIRECTORY , MODEL_FILE_NAME , LIBRARY_NAME
46
+ from tensorflow_tts import __version__ as VERSION
47
+ from huggingface_hub import hf_hub_url , cached_download
43
48
44
49
45
50
TF_MODEL_MAPPING = OrderedDict (
@@ -62,8 +67,35 @@ def __init__(self):
62
67
raise EnvironmentError ("Cannot be instantiated using `__init__()`" )
63
68
64
69
@classmethod
65
- def from_pretrained (cls , config , pretrained_path = None , ** kwargs ):
70
+ def from_pretrained (cls , config = None , pretrained_path = None , ** kwargs ):
66
71
is_build = kwargs .pop ("is_build" , True )
72
+
73
+ # load weights from hf hub
74
+ if pretrained_path is not None :
75
+ if not os .path .isfile (pretrained_path ):
76
+ # retrieve correct hub url
77
+ download_url = hf_hub_url (repo_id = pretrained_path , filename = MODEL_FILE_NAME )
78
+
79
+ downloaded_file = str (
80
+ cached_download (
81
+ url = download_url ,
82
+ library_name = LIBRARY_NAME ,
83
+ library_version = VERSION ,
84
+ cache_dir = CACHE_DIRECTORY ,
85
+ )
86
+ )
87
+
88
+ # load config from repo as well
89
+ if config is None :
90
+ from tensorflow_tts .inference import AutoConfig
91
+
92
+ config = AutoConfig .from_pretrained (pretrained_path )
93
+
94
+ pretraine_path = downloaded_file
95
+
96
+
97
+ assert config is not None , "Please make sure to pass a config along to load a model from a local file"
98
+
67
99
for config_class , model_class in TF_MODEL_MAPPING .items ():
68
100
if isinstance (config , config_class ) and str (config_class .__name__ ) in str (
69
101
config
@@ -79,6 +111,7 @@ def from_pretrained(cls, config, pretrained_path=None, **kwargs):
79
111
pretrained_path , by_name = True , skip_mismatch = True
80
112
)
81
113
return model
114
+
82
115
raise ValueError (
83
116
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n "
84
117
"Model type should be one of {}." .format (
0 commit comments