diff --git a/neural_compressor/utils/load_huggingface.py b/neural_compressor/utils/load_huggingface.py index 6671640722d..c1536736c01 100644 --- a/neural_compressor/utils/load_huggingface.py +++ b/neural_compressor/utils/load_huggingface.py @@ -83,7 +83,7 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module: **kwargs, ) - model_class = eval(f"transformers.{config.architectures[0]}") + model_class = transformers.AutoModelForSequenceClassification if config.torch_dtype is not torch.int8: model = model_class.from_pretrained( model_name_or_path, @@ -96,81 +96,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module: ) return model else: - logger.info("the quantization optimized model is loading.") - keys_to_ignore_on_load_unexpected = copy.deepcopy( - getattr(model_class, "_keys_to_ignore_on_load_unexpected", None) - ) - keys_to_ignore_on_load_missing = copy.deepcopy( - getattr(model_class, "_keys_to_ignore_on_load_missing", None) - ) - - # Avoid unnecessary warnings resulting from quantized model initialization - quantized_keys_to_ignore_on_load = [ - r"zero_point", - r"scale", - r"packed_params", - r"constant", - r"module", - r"best_configure", - ] - if keys_to_ignore_on_load_unexpected is None: - model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load - else: - model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load) - missing_keys_to_ignore_on_load = [r"weight", r"bias"] - if keys_to_ignore_on_load_missing is None: - model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load - else: # pragma: no cover - model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load) - - if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover - from transformers.utils import cached_file - - try: - # Load from URL or cache if already cached - resolved_weights_file = cached_file( - model_name_or_path, - filename=WEIGHTS_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - use_auth_token=use_auth_token, - ) - except EnvironmentError as err: # pragma: no cover - logger.error(err) - msg = ( - f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n" - f"- '{model_name_or_path}' is a correct model identifier " - f"listed on 'https://huggingface.co/models'\n (make sure " - f"'{model_name_or_path}' is not a path to a local directory with " - f"something else, in that case)\n\n- or '{model_name_or_path}' is " - f"the correct path to a directory containing a file " - f"named one of {WEIGHTS_NAME}\n\n" - ) - if revision is not None: - msg += ( - f"- or '{revision}' is a valid git identifier " - f"(branch name, a tag name, or a commit id) that " - f"exists for this model name as listed on its model " - f"page on 'https://huggingface.co/models'\n\n" - ) - raise EnvironmentError(msg) - else: - resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(resolved_weights_file, {}) - model = model_class.from_pretrained( - model_name_or_path, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - use_auth_token=use_auth_token, - revision=revision, - state_dict=state_dict, - **kwargs, - ) - - model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected - model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing + config.torch_dtype = torch.float32 + model = model_class.from_config(config) if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover # pylint: disable=E0611