@@ -718,10 +718,12 @@ def tmp(_, inp, out):
                 for n, p in sub_layer.named_parameters():
                     param_name = full_layer_name + "." + n
                     if n == "weight":
-                        set_module_tensor_to_device(self.model, param_name, self.device, Q)
+                        set_module_tensor_to_device(self.model, param_name, self.device, Q, dtype=Q.dtype)
                     else:
                         value = load_value(self.model, param_name, model_path)
-                        set_module_tensor_to_device(self.model, param_name, self.device, value)
+                        set_module_tensor_to_device(
+                            self.model, param_name, self.device, value, dtype=value.dtype
+                        )
                 # sub_layer.weight.data = Q
                 torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
                 clean_module_weight(sub_layer)
@@ -745,6 +747,8 @@ def tmp(_, inp, out):
         for j in range(len(self.dataloader)):
             cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
             cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+            # breakpoint()
+            # transformer_block = transformer_block.to(getattr(torch, self.model.config.torch_dtype))
             out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
            out = self.track_hidden_states(out)
             outs.append(out)
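
For context: in recent versions of accelerate, set_module_tensor_to_device casts the supplied value to the dtype of the parameter it replaces when no dtype argument is given, so passing dtype=Q.dtype / dtype=value.dtype keeps the loaded tensor in its own precision. A minimal sketch of that pattern on a toy module (the Linear layer and tensor values below are illustrative, not from this repository):

import torch
from torch import nn
from accelerate.utils import set_module_tensor_to_device

# Toy stand-in for one sub-layer; its weight starts out as fp32.
model = nn.Sequential(nn.Linear(4, 4))

# A loaded/quantized weight that should keep its own precision (fp16 here).
Q = torch.randn(4, 4, dtype=torch.float16)

# Without dtype=..., accelerate would cast Q back to the old fp32 dtype;
# passing dtype=Q.dtype preserves it.
set_module_tensor_to_device(model, "0.weight", "cpu", Q, dtype=Q.dtype)
assert model[0].weight.dtype == torch.float16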