
Commit f21b4d2

[Dtype] Align dtype casting behavior with Transformers and Accelerate (huggingface#1725)
* [Dtype] Align automatic dtype
* up
* up
* fix
* re-add accelerate
1 parent fbbecfa commit f21b4d2
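
The core of the change is a small compatibility check: before forwarding the requested torch_dtype to Accelerate's set_module_tensor_to_device, the loader inspects that function's signature, so Accelerate versions whose set_module_tensor_to_device has no dtype parameter keep working. A minimal sketch of the pattern outside diffusers (the place_param helper and its arguments are illustrative, not part of the commit):

    import inspect

    from accelerate.utils import set_module_tensor_to_device


    def place_param(model, param_name, device, value, torch_dtype=None):
        # Older accelerate releases do not expose a `dtype` argument, so only
        # forward it when the installed version's signature accepts it.
        accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters
        if accepts_dtype and torch_dtype is not None:
            set_module_tensor_to_device(model, param_name, device, value=value, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=value)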

File tree

1 file changed: +7 −16 lines changed


modeling_utils.py

Lines changed: 7 additions & 16 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
@@ -489,11 +490,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 state_dict = load_state_dict(model_file)
                 # move the parms from meta device to cpu
                 for param_name, param in state_dict.items():
-                    set_module_tensor_to_device(model, param_name, param_device, value=param)
+                    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+                    if accepts_dtype:
+                        set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
+                    else:
+                        set_module_tensor_to_device(model, param_name, param_device, value=param)
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by deafult the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
+                accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
 
             loading_info = {
                 "missing_keys": [],
@@ -519,20 +524,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls.from_config(config, **unused_kwargs)
 
             state_dict = load_state_dict(model_file)
-            dtype = set(v.dtype for v in state_dict.values())
-
-            if len(dtype) > 1 and torch.float32 not in dtype:
-                raise ValueError(
-                    f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
-                    f" make sure that {model_file} weights have only one dtype."
-                )
-            elif len(dtype) > 1 and torch.float32 in dtype:
-                dtype = torch.float32
-            else:
-                dtype = dtype.pop()
-
-            # move model to correct dtype
-            model = model.to(dtype)
 
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                 model,
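
From the user side, the effect is that a dtype requested at load time is now applied while Accelerate places the weights, rather than being inferred afterwards from the dtypes found in the checkpoint. A hedged usage sketch (the model id and subfolder are illustrative, not taken from this commit):

    import torch
    from diffusers import UNet2DConditionModel

    # Request half precision at load time; with this commit the requested dtype
    # is passed through to accelerate during weight placement instead of being
    # derived from the checkpoint's own dtypes after loading.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    assert next(unet.parameters()).dtype == torch.float16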
