 
 
 ##################################### save ##################################
-
 def remove_rank_suffix(name, local_rank, world_size):
     """Remove rank suffix from key name."""
     return name.removesuffix(f"_{local_rank}_{world_size}")
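
For reference, a minimal sketch of the rank-suffix naming convention this helper undoes; the `add_rank_suffix` counterpart used later in this diff is assumed to simply append the same suffix (the key name below is illustrative):

```python
# Assumed counterpart of remove_rank_suffix: append "_{local_rank}_{world_size}" to a key.
def add_rank_suffix_sketch(name: str, local_rank: int, world_size: int) -> str:
    return f"{name}_{local_rank}_{world_size}"

key = add_rank_suffix_sketch("model.layers.0.mlp.down_proj.weight", 1, 2)
# 'model.layers.0.mlp.down_proj.weight_1_2'
assert key.removesuffix("_1_2") == "model.layers.0.mlp.down_proj.weight"
```
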
@@ -291,54 +290,6 @@ def convert_config_to_vllm_compatible(config):
     }
     return quantization_config
 
-def gather_and_save_model(tp_model, original_model, world_size, max_shard_size, output_dir):
-    """
-    Merge the weights of a DeepSpeed Tensor Parallel model into the original model
-    and save them as sharded safetensors files.
-
-    Args:
-        tp_model: DeepSpeed tensor parallel model.
-        original_model: The original unsharded model.
-        world_size: The model parallelism size.
-        max_shard_size: The maximum size of each shard in bytes.
-        output_dir: The directory to save the files.
-    """
-    import deepspeed
-    merged_state_dict = {}
-
-    for name, tp_param in tp_model.named_parameters():
-        if "scale" not in name:
-            param = original_model.state_dict()[name].t()
-            if len(param.shape) != 0 and tp_param.shape != param.shape:
-                # Perform all-gather to merge tensor parallel parameters
-                tensor_list = [torch.zeros_like(tp_param) for _ in range(world_size)]
-                deepspeed.comm.all_gather(tensor_list, tp_param)
-                if local_rank == 0:
-                    if tp_param.shape[0] != param.shape[0]:
-                        merged_param = torch.cat(tensor_list, dim=0)
-                    elif tp_param.shape[1] != param.shape[1]:
-                        merged_param = torch.cat(tensor_list, dim=1)
-                    logger.info(f"[Rank {local_rank}] Merging parameter: {name}")
-            else:
-                if local_rank == 0:
-                    # No merging needed if shapes match
-                    merged_param = tp_param
-                    logger.info(f"[Rank {local_rank}] No merging needed for parameter: {name}")
-        else:
-            deepspeed.comm.all_reduce(tp_param, op=deepspeed.comm.ReduceOp.MAX)
-            if local_rank == 0:
-                merged_param = tp_param
-                logger.info(f"[Rank {local_rank}] No merging needed for parameter: {name}")
-        # Update the merged state dict
-        if local_rank == 0:
-            merged_state_dict[name] = merged_param
-
-    # Save the merged state dict as sharded safetensors files (only on rank 0)
-    if local_rank == 0:
-        logger.info(f"[Rank {local_rank}] Saving model shards to {output_dir}...")
-        merged_state_dict = convert_weight_to_vllm_compatible(state_dict=merged_state_dict)
-        save_state_dict_sharded_safetensors(merged_state_dict, output_dir, f"{MAX_FILE_SIZE}GB")
-
 
 def check_config_for_vllm_compatible(config):
     """
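
The removed `gather_and_save_model` decided how to merge each tensor-parallel shard by comparing the sharded and full shapes, concatenating along dim 0 or dim 1 accordingly. A standalone sketch of that merge rule in plain torch (no DeepSpeed collectives, made-up tensors):

```python
import torch

def merge_tp_shards(shards, full_shape):
    """Illustrative: concatenate TP shards along whichever dimension was split."""
    shard = shards[0]
    if tuple(shard.shape) == tuple(full_shape):
        return shard  # shapes already match, nothing to merge
    dim = 0 if shard.shape[0] != full_shape[0] else 1
    return torch.cat(shards, dim=dim)

full = torch.arange(12.0).reshape(4, 3)
col_shards = list(torch.chunk(full, 2, dim=1))  # column-parallel split
row_shards = list(torch.chunk(full, 2, dim=0))  # row-parallel split
assert torch.equal(merge_tp_shards(col_shards, full.shape), full)
assert torch.equal(merge_tp_shards(row_shards, full.shape), full)
```
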
@@ -362,37 +313,30 @@ def save_for_multi_devices(model, checkpoint_dir="saved_results", format="huggin
         format (str, optional): defaults to 'huggingface'.
     """
     from safetensors.torch import save_file as safe_save_file
-    if format == SaveLoadFormat.VLLM:
-        import transformers
-        from accelerate import init_empty_weights
-        with init_empty_weights(include_buffers=False):
-            reference_model = transformers.AutoModelForCausalLM.from_config(model.config)
-            gather_and_save_model(
-                model, reference_model, world_size=world_size, max_shard_size=f"{MAX_FILE_SIZE}GB", output_dir=checkpoint_dir
+    folder_prefix = os.path.join(options.workspace, checkpoint_dir)
+    save_rank_model(model, folder_prefix=folder_prefix, **kwargs)
+    # Ensure all ranks have saved their model before proceeding
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+    rank_directory = add_rank_suffix(folder_prefix, 0, world_size)
+    files_list = find_safetensors_files(rank_directory)
+    # use rank:0 process to gather checkpoint files
+    if local_rank == 0:
+        tp_mod_list = find_tp_mod_list(model)
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # get the safetensors file list from one folder
+        # based on the safetensors file name to collect tensors from shard folders
+        for file_name in files_list:
+            gathered_state_dict = gather_state_dict(
+                folder_prefix=folder_prefix, file_name=file_name, tp_mod_list=tp_mod_list
             )
-    else:
-        folder_prefix = os.path.join(options.workspace, checkpoint_dir)
-        save_rank_model(model, folder_prefix=folder_prefix, **kwargs)
-        # Ensure all ranks have saved their model before proceeding
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()
-        rank_directory = add_rank_suffix(folder_prefix, 0, world_size)
-        files_list = find_safetensors_files(rank_directory)
-        # use rank:0 process to gather checkpoint files
-        if local_rank == 0:
-            tp_mod_list = find_tp_mod_list(model)
-            os.makedirs(checkpoint_dir, exist_ok=True)
-            # get the safetensors file list from one folder
-            # based on the safetensors file name to collect tensors from shard folders
-            for file_name in files_list:
-                gathered_state_dict = gather_state_dict(
-                    folder_prefix=folder_prefix, file_name=file_name, tp_mod_list=tp_mod_list
-                )
-                safe_save_file(gathered_state_dict, os.path.join(checkpoint_dir, file_name), metadata={"format": "pt"})
-                clean_rank_files(folder_prefix=folder_prefix, file_name=file_name)
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()
-        clean_rank_files(folder_prefix=folder_prefix)
+            if format == SaveLoadFormat.VLLM:
+                gathered_state_dict = update_to_vllm_compatible(model, gathered_state_dict)
+            safe_save_file(gathered_state_dict, os.path.join(checkpoint_dir, file_name), metadata={"format": "pt"})
+            clean_rank_files(folder_prefix=folder_prefix, file_name=file_name)
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+    clean_rank_files(folder_prefix=folder_prefix)
 
 
 def save_for_single_device(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
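
In the replacement flow above, every rank first writes its own shard into a rank-suffixed workspace folder, and only rank 0 then re-assembles each safetensors file across those folders. A rough sketch of the folder layout being assumed (names follow the same suffix convention; paths and file names are illustrative):

```python
def rank_folders(folder_prefix, world_size):
    # Assumed layout: one folder per rank, suffixed like the parameter keys above.
    return [f"{folder_prefix}_{rank}_{world_size}" for rank in range(world_size)]

print(rank_folders("workspace/saved_results", 2))
# ['workspace/saved_results_0_2', 'workspace/saved_results_1_2']
# rank 0 reads e.g. model-00001-of-00002.safetensors from each rank folder,
# merges the tensors, and writes the merged file into saved_results/.
```
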
@@ -558,6 +502,9 @@ def shard_state_dict(state_dict, return_all_rank=False, src_world_size=None):
 def get_rank_state_dict(state_dict, local_rank, world_size):
     rank_state_dict = {}
     for k, v in state_dict.items():
+        if k.endswith(f"_{world_size}") and not k.endswith(f"_{local_rank}_{world_size}"):
+            # only collect current rank state_dict and common state_dict (e.g., embedding)
+            continue
         new_k = remove_rank_suffix(k, local_rank, world_size)
         rank_state_dict[new_k] = v
     return rank_state_dict
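
A self-contained sketch of what the added filter accomplishes for `world_size = 2`: each rank keeps its own suffixed keys (renamed back) plus un-suffixed common keys such as embeddings, and skips keys that carry another rank's suffix (key names and values below are made up):

```python
def get_rank_state_dict_sketch(state_dict, local_rank, world_size):
    # Mirrors the patched logic: drop keys that belong to a different rank.
    rank_state_dict = {}
    for k, v in state_dict.items():
        if k.endswith(f"_{world_size}") and not k.endswith(f"_{local_rank}_{world_size}"):
            continue
        rank_state_dict[k.removesuffix(f"_{local_rank}_{world_size}")] = v
    return rank_state_dict

full = {
    "mlp.down_proj.weight_0_2": "rank0 shard",
    "mlp.down_proj.weight_1_2": "rank1 shard",
    "embed_tokens.weight": "shared",  # common tensor, no rank suffix
}
assert get_rank_state_dict_sketch(full, 0, 2) == {
    "mlp.down_proj.weight": "rank0 shard",
    "embed_tokens.weight": "shared",
}
```
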
@@ -621,7 +568,11 @@ def get_inc_fp8config(model, from_neuralmagic=False, from_neuralmagic_with_kv=Fa
         allowlist = {"types": ["Linear", "LinearLayer", "LinearAllreduce"], "names": []}
         qconfig = FP8Config(mode="LOAD", allowlist=allowlist, blocklist=blocklist, scale_format="CONST")
     else:
-        qconfig = FP8Config.from_dict(model.config.quantization_config)
+        if hasattr(model, "qconfig") and model.qconfig is not None:
+            configs_mapping = model.qconfig
+            qconfig = configs_mapping[next(iter(configs_mapping))]
+        else:
+            qconfig = FP8Config.from_dict(model.config.quantization_config)
     return qconfig
 
 
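
The new branch prefers the in-memory `model.qconfig` mapping and only falls back to the serialized `quantization_config`; the lookup simply takes the first entry of that mapping. A minimal sketch of the idiom with a stand-in dict (the real object is an INC configs mapping, not reproduced here):

```python
# Stand-in for model.qconfig: an ordered mapping whose first value is the config object.
configs_mapping = {("fp8_quant", "default"): {"mode": "QUANTIZE", "scale_format": "CONST"}}

first_key = next(iter(configs_mapping))  # dicts preserve insertion order in Python 3.7+
qconfig = configs_mapping[first_key]
print(qconfig["mode"])  # QUANTIZE
```
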
@@ -678,7 +629,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
             raise EnvironmentError(
                 f"Original world_size: {src_world_size} must be divisible by target world_size: {world_size}."
             )
-        rank_state_dict = get_new_rank_state_dict(rank_state_dict, world_size, model=model)
+        rank_state_dict = get_new_rank_state_dict(rank_state_dict, model, world_size)
     else:
         rank_state_dict = gathered_state_dict
     model.load_state_dict(rank_state_dict, assign=True, strict=False)
@@ -816,8 +767,16 @@ def load_scale_params(model, new_scale_params):
             param.data = new_scale
 
 
-def get_new_rank_state_dict(all_rank_state_dict, world_size=world_size, local_rank=local_rank, model=None):
-    """Get new rank state_dict for world_size."""
+def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
+    """Get new rank state_dict for world_size.
+
+    Args:
+        all_rank_state_dict (dict): {0: state_dict, 1: state_dict, ...} for all ranks.
+        model (torch.nn.Module): A quantized model with empty weights in load mode.
+        world_size (int, optional): Target world size. Defaults to world_size.
+        local_rank (int, optional): Current local rank. Defaults to local_rank.
+    """
+
     def dq_q_weight(weight, previous_scale, target_scale):
         """dequantize and quantize weight with different scales."""
         cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, ScaleFormat.CONST)
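
`dq_q_weight` re-expresses an FP8 weight under a different scale: dequantize with the previous scale, then re-quantize with the target scale. A plain-PyTorch sketch of that idea (the real code goes through INC's CAST_TO_FP8 wrapper; this sketch assumes a PyTorch build with the `float8_e4m3fn` dtype):

```python
import torch

def dq_q_weight_sketch(weight_fp8, previous_scale, target_scale):
    # Dequantize with the old scale, then re-quantize with the new one.
    dequantized = weight_fp8.to(torch.bfloat16) * previous_scale
    return (dequantized / target_scale).to(torch.float8_e4m3fn)

w = torch.randn(4, 4, dtype=torch.bfloat16)
w_fp8 = (w / 2.0).to(torch.float8_e4m3fn)  # pretend w was quantized with scale 2.0
w_rescaled = dq_q_weight_sketch(w_fp8, torch.tensor(2.0), torch.tensor(0.5))
```
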
@@ -936,3 +895,34 @@ def update_corresponding_weight(previous_state_dict, cur_state_dict, target_weig
             tp_dim = 0 if params_dict[k].shape[0] != new_rank_state_dict[k].shape[0] else 1
             new_rank_state_dict[k] = torch.cat([new_rank_state_dict[k], v], dim=tp_dim)
     return new_rank_state_dict
+
+
+def update_to_vllm_compatible(model, gathered_state_dict):
+    """Update gathered_state_dict to vLLM compatible format.
+
+    Args:
+        model (torch.nn.Module): the quantized model.
+        gathered_state_dict (dict): state_dict gathered from all ranks.
+    Returns:
+        gathered_state_dict (dict): vLLM compatible state_dict.
+    """
+    # for example, tp = 2, make
+    # 'model.layers.21.mlp.down_proj.weight_0_2', 'model.layers.21.mlp.down_proj.weight_1_2'
+    # --> 'model.layers.21.mlp.down_proj.weight'
+    import transformers
+    from accelerate import init_empty_weights
+    from neural_compressor.torch.algorithms.fp8_quant import prep_model
+
+    with init_empty_weights(include_buffers=False):
+        reference_model = transformers.AutoModelForCausalLM.from_config(model.config)
+    # replace modules with patched modules
+    inc_config = get_inc_fp8config(model)
+    inc_config.mode = "LOAD"
+    inc_config.save_temp_json_file()
+    prep_model(reference_model, inc_config.json_file)
+    # gather weights into 1 rank
+    rank_state_dict = shard_state_dict(gathered_state_dict, return_all_rank=True, src_world_size=world_size)
+    rank_state_dict = get_new_rank_state_dict(rank_state_dict, reference_model, world_size=1)
+    # rename param names
+    gathered_state_dict = convert_weight_to_vllm_compatible(state_dict=rank_state_dict)
+    return gathered_state_dict
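
One design point worth noting: the reference model is built under `init_empty_weights`, so its parameters live on the meta device and no real weight storage is allocated while INC patches the modules for LOAD mode; the gathered state dict is then re-sharded to a single rank before the keys are renamed. A short sketch of the meta-device part (the model name is illustrative; requires `transformers` and `accelerate`):

```python
import transformers
from accelerate import init_empty_weights

# Any causal-LM config works here; "facebook/opt-125m" is just an example.
config = transformers.AutoConfig.from_pretrained("facebook/opt-125m")
with init_empty_weights(include_buffers=False):
    reference_model = transformers.AutoModelForCausalLM.from_config(config)
print(next(reference_model.parameters()).device)  # meta: shapes only, no storage
```
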