
Commit 8b035eb

fix vulnerability (#2149)

Signed-off-by: xin3he <[email protected]>

1 parent 22ffcae

1 file changed: neural_compressor/utils/load_huggingface.py (3 additions, 76 deletions)
@@ -83,7 +83,7 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
             **kwargs,
         )

-        model_class = eval(f"transformers.{config.architectures[0]}")
+        model_class = transformers.AutoModelForSequenceClassification
         if config.torch_dtype is not torch.int8:
             model = model_class.from_pretrained(
                 model_name_or_path,
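Why the first hunk matters: config.architectures[0] is read from a model's config.json, which is attacker-controlled for any checkpoint pulled from a remote hub, so passing it through eval() lets a crafted config execute arbitrary Python at load time. Below is a minimal sketch of the injection, plus a getattr-based lookup as a hypothetical safer alternative (the commit itself instead pins the class to transformers.AutoModelForSequenceClassification):

import transformers

# Sketch only, not code from this commit. An "architectures" entry like
# this becomes a valid Python expression once prefixed with "transformers.":
malicious_architecture = "AutoModel and __import__('os').system('id')"

# The removed line was equivalent to:
#     model_class = eval(f"transformers.{malicious_architecture}")
# which evaluates transformers.AutoModel (truthy), then runs `id` via
# os.system as a side effect of merely loading the model.

# Hypothetical safer dynamic lookup: getattr() resolves an attribute name
# and cannot evaluate expressions; non-class names are rejected.
def resolve_architecture(name: str):
    model_class = getattr(transformers, name, None)
    if not isinstance(model_class, type):
        raise ValueError(f"unknown transformers architecture: {name}")
    return model_class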
@@ -96,81 +96,8 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
             )
             return model
         else:
-            logger.info("the quantization optimized model is loading.")
-            keys_to_ignore_on_load_unexpected = copy.deepcopy(
-                getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
-            )
-            keys_to_ignore_on_load_missing = copy.deepcopy(
-                getattr(model_class, "_keys_to_ignore_on_load_missing", None)
-            )
-
-            # Avoid unnecessary warnings resulting from quantized model initialization
-            quantized_keys_to_ignore_on_load = [
-                r"zero_point",
-                r"scale",
-                r"packed_params",
-                r"constant",
-                r"module",
-                r"best_configure",
-            ]
-            if keys_to_ignore_on_load_unexpected is None:
-                model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
-            else:
-                model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
-            missing_keys_to_ignore_on_load = [r"weight", r"bias"]
-            if keys_to_ignore_on_load_missing is None:
-                model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
-            else:  # pragma: no cover
-                model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
-
-            if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):  # pragma: no cover
-                from transformers.utils import cached_file
-
-                try:
-                    # Load from URL or cache if already cached
-                    resolved_weights_file = cached_file(
-                        model_name_or_path,
-                        filename=WEIGHTS_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        use_auth_token=use_auth_token,
-                    )
-                except EnvironmentError as err:  # pragma: no cover
-                    logger.error(err)
-                    msg = (
-                        f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
-                        f"- '{model_name_or_path}' is a correct model identifier "
-                        f"listed on 'https://huggingface.co/models'\n (make sure "
-                        f"'{model_name_or_path}' is not a path to a local directory with "
-                        f"something else, in that case)\n\n- or '{model_name_or_path}' is "
-                        f"the correct path to a directory containing a file "
-                        f"named one of {WEIGHTS_NAME}\n\n"
-                    )
-                    if revision is not None:
-                        msg += (
-                            f"- or '{revision}' is a valid git identifier "
-                            f"(branch name, a tag name, or a commit id) that "
-                            f"exists for this model name as listed on its model "
-                            f"page on 'https://huggingface.co/models'\n\n"
-                        )
-                    raise EnvironmentError(msg)
-            else:
-                resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
-            state_dict = torch.load(resolved_weights_file, {})
-            model = model_class.from_pretrained(
-                model_name_or_path,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                state_dict=state_dict,
-                **kwargs,
-            )
-
-            model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
-            model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing
+            config.torch_dtype = torch.float32
+            model = model_class.from_config(config)

         if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):  # pragma: no cover
             # pylint: disable=E0611
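The second hunk removes the torch.load() path entirely, which is the other half of the vulnerability: torch.load() unpickles the checkpoint by default, and unpickling untrusted data can execute arbitrary code through __reduce__. A standalone sketch of the risk (illustrative, not this repo's code):

import os
import torch

# A pickled object whose __reduce__ runs a shell command the moment
# torch.load() deserializes it.
class Exploit:
    def __reduce__(self):
        return (os.system, ("echo pwned",))

torch.save(Exploit(), "malicious.pt")
# torch.load("malicious.pt")                     # runs `echo pwned`
# torch.load("malicious.pt", weights_only=True)  # raises UnpicklingError
# (weights_only, available since PyTorch 1.13, restricts unpickling to
# tensors and plain containers.)

With this commit, the quantized branch no longer deserializes a state dict at all; it builds the architecture from the already parsed config and leaves weight loading to later steps. A hedged usage sketch of the new path, with a placeholder model identifier:

import torch
import transformers

config = transformers.AutoConfig.from_pretrained("some-org/int8-model")  # placeholder id
config.torch_dtype = torch.float32
# from_config() instantiates the architecture with freshly initialized
# weights and reads nothing from the checkpoint, so no untrusted pickle
# is ever executed here.
model = transformers.AutoModelForSequenceClassification.from_config(config)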
