@@ -83,7 +83,7 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
8383 ** kwargs ,
8484 )
8585
86- model_class = eval ( f" transformers.{ config . architectures [ 0 ] } " )
86+ model_class = transformers .AutoModelForSequenceClassification
8787 if config .torch_dtype is not torch .int8 :
8888 model = model_class .from_pretrained (
8989 model_name_or_path ,
@@ -96,81 +96,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
9696 )
9797 return model
9898 else :
99- logger .info ("the quantization optimized model is loading." )
100- keys_to_ignore_on_load_unexpected = copy .deepcopy (
101- getattr (model_class , "_keys_to_ignore_on_load_unexpected" , None )
102- )
103- keys_to_ignore_on_load_missing = copy .deepcopy (
104- getattr (model_class , "_keys_to_ignore_on_load_missing" , None )
105- )
106-
107- # Avoid unnecessary warnings resulting from quantized model initialization
108- quantized_keys_to_ignore_on_load = [
109- r"zero_point" ,
110- r"scale" ,
111- r"packed_params" ,
112- r"constant" ,
113- r"module" ,
114- r"best_configure" ,
115- ]
116- if keys_to_ignore_on_load_unexpected is None :
117- model_class ._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
118- else :
119- model_class ._keys_to_ignore_on_load_unexpected .extend (quantized_keys_to_ignore_on_load )
120- missing_keys_to_ignore_on_load = [r"weight" , r"bias" ]
121- if keys_to_ignore_on_load_missing is None :
122- model_class ._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
123- else : # pragma: no cover
124- model_class ._keys_to_ignore_on_load_missing .extend (missing_keys_to_ignore_on_load )
125-
126- if not os .path .isdir (model_name_or_path ) and not os .path .isfile (model_name_or_path ): # pragma: no cover
127- from transformers .utils import cached_file
128-
129- try :
130- # Load from URL or cache if already cached
131- resolved_weights_file = cached_file (
132- model_name_or_path ,
133- filename = WEIGHTS_NAME ,
134- cache_dir = cache_dir ,
135- force_download = force_download ,
136- resume_download = resume_download ,
137- use_auth_token = use_auth_token ,
138- )
139- except EnvironmentError as err : # pragma: no cover
140- logger .error (err )
141- msg = (
142- f"Can't load weights for '{ model_name_or_path } '. Make sure that:\n \n "
143- f"- '{ model_name_or_path } ' is a correct model identifier "
144- f"listed on 'https://huggingface.co/models'\n (make sure "
145- f"'{ model_name_or_path } ' is not a path to a local directory with "
146- f"something else, in that case)\n \n - or '{ model_name_or_path } ' is "
147- f"the correct path to a directory containing a file "
148- f"named one of { WEIGHTS_NAME } \n \n "
149- )
150- if revision is not None :
151- msg += (
152- f"- or '{ revision } ' is a valid git identifier "
153- f"(branch name, a tag name, or a commit id) that "
154- f"exists for this model name as listed on its model "
155- f"page on 'https://huggingface.co/models'\n \n "
156- )
157- raise EnvironmentError (msg )
158- else :
159- resolved_weights_file = os .path .join (model_name_or_path , WEIGHTS_NAME )
160- state_dict = torch .load (resolved_weights_file , {})
161- model = model_class .from_pretrained (
162- model_name_or_path ,
163- cache_dir = cache_dir ,
164- force_download = force_download ,
165- resume_download = resume_download ,
166- use_auth_token = use_auth_token ,
167- revision = revision ,
168- state_dict = state_dict ,
169- ** kwargs ,
170- )
171-
172- model_class ._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
173- model_class ._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing
99+ config .torch_dtype = torch .float32
100+ model = model_class .from_config (config )
174101
175102 if not os .path .isdir (model_name_or_path ) and not os .path .isfile (model_name_or_path ): # pragma: no cover
176103 # pylint: disable=E0611
0 commit comments