1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ import inspect
1718import os
1819from functools import partial
1920from typing import Callable , List , Optional , Tuple , Union
@@ -489,11 +490,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
489490 state_dict = load_state_dict (model_file )
490491 # move the parms from meta device to cpu
491492 for param_name , param in state_dict .items ():
492- set_module_tensor_to_device (model , param_name , param_device , value = param )
493+ accepts_dtype = "dtype" in set (inspect .signature (set_module_tensor_to_device ).parameters .keys ())
494+ if accepts_dtype :
495+ set_module_tensor_to_device (model , param_name , param_device , value = param , dtype = torch_dtype )
496+ else :
497+ set_module_tensor_to_device (model , param_name , param_device , value = param )
493498 else : # else let accelerate handle loading and dispatching.
494499 # Load weights and dispatch according to the device_map
495500 # by deafult the device_map is None and the weights are loaded on the CPU
496- accelerate .load_checkpoint_and_dispatch (model , model_file , device_map )
501+ accelerate .load_checkpoint_and_dispatch (model , model_file , device_map , dtype = torch_dtype )
497502
498503 loading_info = {
499504 "missing_keys" : [],
@@ -519,20 +524,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
519524 model = cls .from_config (config , ** unused_kwargs )
520525
521526 state_dict = load_state_dict (model_file )
522- dtype = set (v .dtype for v in state_dict .values ())
523-
524- if len (dtype ) > 1 and torch .float32 not in dtype :
525- raise ValueError (
526- f"The weights of the model file { model_file } have a mixture of incompatible dtypes { dtype } . Please"
527- f" make sure that { model_file } weights have only one dtype."
528- )
529- elif len (dtype ) > 1 and torch .float32 in dtype :
530- dtype = torch .float32
531- else :
532- dtype = dtype .pop ()
533-
534- # move model to correct dtype
535- model = model .to (dtype )
536527
537528 model , missing_keys , unexpected_keys , mismatched_keys , error_msgs = cls ._load_pretrained_model (
538529 model ,
0 commit comments