Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 3 additions & 76 deletions neural_compressor/utils/load_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
**kwargs,
)

model_class = eval(f"transformers.{config.architectures[0]}")
model_class = transformers.AutoModelForSequenceClassification
if config.torch_dtype is not torch.int8:
model = model_class.from_pretrained(
model_name_or_path,
Expand All @@ -96,81 +96,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
)
return model
else:
logger.info("the quantization optimized model is loading.")
keys_to_ignore_on_load_unexpected = copy.deepcopy(
getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
)
keys_to_ignore_on_load_missing = copy.deepcopy(
getattr(model_class, "_keys_to_ignore_on_load_missing", None)
)

# Avoid unnecessary warnings resulting from quantized model initialization
quantized_keys_to_ignore_on_load = [
r"zero_point",
r"scale",
r"packed_params",
r"constant",
r"module",
r"best_configure",
]
if keys_to_ignore_on_load_unexpected is None:
model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
else:
model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
missing_keys_to_ignore_on_load = [r"weight", r"bias"]
if keys_to_ignore_on_load_missing is None:
model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
else: # pragma: no cover
model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)

if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover
from transformers.utils import cached_file

try:
# Load from URL or cache if already cached
resolved_weights_file = cached_file(
model_name_or_path,
filename=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
use_auth_token=use_auth_token,
)
except EnvironmentError as err: # pragma: no cover
logger.error(err)
msg = (
f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
f"- '{model_name_or_path}' is a correct model identifier "
f"listed on 'https://huggingface.co/models'\n (make sure "
f"'{model_name_or_path}' is not a path to a local directory with "
f"something else, in that case)\n\n- or '{model_name_or_path}' is "
f"the correct path to a directory containing a file "
f"named one of {WEIGHTS_NAME}\n\n"
)
if revision is not None:
msg += (
f"- or '{revision}' is a valid git identifier "
f"(branch name, a tag name, or a commit id) that "
f"exists for this model name as listed on its model "
f"page on 'https://huggingface.co/models'\n\n"
)
raise EnvironmentError(msg)
else:
resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
state_dict = torch.load(resolved_weights_file, {})
model = model_class.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
use_auth_token=use_auth_token,
revision=revision,
state_dict=state_dict,
**kwargs,
)

model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing
config.torch_dtype = torch.float32
model = model_class.from_config(config)

if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover
# pylint: disable=E0611
Expand Down
Loading