
Commit 283de79

changwangss and xinhe3 authored
[SW-219751]improve vllm compatible save function (#217)
* improve vllm compatible save to avoid OOM

---------

Signed-off-by: changwangss <[email protected]>
Signed-off-by: Xin He <[email protected]>
Co-authored-by: Xin He <[email protected]>
1 parent a7eaaa9 commit 283de79

File tree: 2 files changed (+76 / -84 lines)


neural_compressor/torch/algorithms/fp8_quant/save_load.py

Lines changed: 73 additions & 83 deletions
@@ -60,7 +60,6 @@
 
 
 ##################################### save ##################################
-
 def remove_rank_suffix(name, local_rank, world_size):
     """Remove rank suffix from key name."""
     return name.removesuffix(f"_{local_rank}_{world_size}")
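For orientation: the multi-device save path names per-rank artifacts with a `_{local_rank}_{world_size}` suffix, and remove_rank_suffix strips it via str.removesuffix (Python 3.9+). A minimal sketch of that convention; the add_rank_suffix defined here is an assumed inverse written only for illustration, not copied from this file:

# Illustration of the rank-suffix naming convention; add_rank_suffix below is a
# hypothetical inverse of remove_rank_suffix, defined only for this example.
def add_rank_suffix(name, local_rank, world_size):
    return f"{name}_{local_rank}_{world_size}"

def remove_rank_suffix(name, local_rank, world_size):
    return name.removesuffix(f"_{local_rank}_{world_size}")

key = add_rank_suffix("model.layers.0.mlp.down_proj.weight", 1, 2)
assert key == "model.layers.0.mlp.down_proj.weight_1_2"
assert remove_rank_suffix(key, 1, 2) == "model.layers.0.mlp.down_proj.weight"
# A key that belongs to another rank is left untouched:
assert remove_rank_suffix("model.layers.0.mlp.down_proj.weight_0_2", 1, 2) == "model.layers.0.mlp.down_proj.weight_0_2"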
@@ -291,54 +290,6 @@ def convert_config_to_vllm_compatible(config):
     }
     return quantization_config
 
-def gather_and_save_model(tp_model, original_model, world_size, max_shard_size, output_dir):
-    """
-    Merge the weights of a DeepSpeed Tensor Parallel model into the original model
-    and save them as sharded safetensors files.
-
-    Args:
-        tp_model: DeepSpeed tensor parallel model.
-        original_model: The original unsharded model.
-        world_size: The model parallelism size.
-        max_shard_size: The maximum size of each shard in bytes.
-        output_dir: The directory to save the files.
-    """
-    import deepspeed
-    merged_state_dict = {}
-
-    for name, tp_param in tp_model.named_parameters():
-        if "scale" not in name:
-            param = original_model.state_dict()[name].t()
-            if len(param.shape) != 0 and tp_param.shape != param.shape:
-                # Perform all-gather to merge tensor parallel parameters
-                tensor_list = [torch.zeros_like(tp_param) for _ in range(world_size)]
-                deepspeed.comm.all_gather(tensor_list, tp_param)
-                if local_rank == 0:
-                    if tp_param.shape[0] != param.shape[0]:
-                        merged_param = torch.cat(tensor_list, dim=0)
-                    elif tp_param.shape[1] != param.shape[1]:
-                        merged_param = torch.cat(tensor_list, dim=1)
-                    logger.info(f"[Rank {local_rank}] Merging parameter: {name}")
-            else:
-                if local_rank == 0:
-                    # No merging needed if shapes match
-                    merged_param = tp_param
-                    logger.info(f"[Rank {local_rank}] No merging needed for parameter: {name}")
-        else:
-            deepspeed.comm.all_reduce(tp_param, op=deepspeed.comm.ReduceOp.MAX)
-            if local_rank == 0:
-                merged_param = tp_param
-                logger.info(f"[Rank {local_rank}] No merging needed for parameter: {name}")
-        # Update the merged state dict
-        if local_rank == 0:
-            merged_state_dict[name] = merged_param
-
-    # Save the merged state dict as sharded safetensors files (only on rank 0)
-    if local_rank == 0:
-        logger.info(f"[Rank {local_rank}] Saving model shards to {output_dir}...")
-        merged_state_dict = convert_weight_to_vllm_compatible(state_dict=merged_state_dict)
-        save_state_dict_sharded_safetensors(merged_state_dict, output_dir, f"{MAX_FILE_SIZE}GB")
-
 
 def check_config_for_vllm_compatible(config):
     """
@@ -362,37 +313,30 @@ def save_for_multi_devices(model, checkpoint_dir="saved_results", format="huggin
         format (str, optional): defaults to 'huggingface'.
     """
     from safetensors.torch import save_file as safe_save_file
-    if format == SaveLoadFormat.VLLM:
-        import transformers
-        from accelerate import init_empty_weights
-        with init_empty_weights(include_buffers=False):
-            reference_model = transformers.AutoModelForCausalLM.from_config(model.config)
-        gather_and_save_model(
-            model, reference_model, world_size=world_size, max_shard_size=f"{MAX_FILE_SIZE}GB", output_dir=checkpoint_dir
+    folder_prefix = os.path.join(options.workspace, checkpoint_dir)
+    save_rank_model(model, folder_prefix=folder_prefix, **kwargs)
+    # Ensure all ranks have saved their model before proceeding
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+    rank_directory = add_rank_suffix(folder_prefix, 0, world_size)
+    files_list = find_safetensors_files(rank_directory)
+    # use rank:0 process to gather checkpoint files
+    if local_rank == 0:
+        tp_mod_list = find_tp_mod_list(model)
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # get the safetensors file list from one folder
+        # based on the safetensors file name to collect tensors from shard folders
+        for file_name in files_list:
+            gathered_state_dict = gather_state_dict(
+                folder_prefix=folder_prefix, file_name=file_name, tp_mod_list=tp_mod_list
             )
-    else:
-        folder_prefix = os.path.join(options.workspace, checkpoint_dir)
-        save_rank_model(model, folder_prefix=folder_prefix, **kwargs)
-        # Ensure all ranks have saved their model before proceeding
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()
-        rank_directory = add_rank_suffix(folder_prefix, 0, world_size)
-        files_list = find_safetensors_files(rank_directory)
-        # use rank:0 process to gather checkpoint files
-        if local_rank == 0:
-            tp_mod_list = find_tp_mod_list(model)
-            os.makedirs(checkpoint_dir, exist_ok=True)
-            # get the safetensors file list from one folder
-            # based on the safetensors file name to collect tensors from shard folders
-            for file_name in files_list:
-                gathered_state_dict = gather_state_dict(
-                    folder_prefix=folder_prefix, file_name=file_name, tp_mod_list=tp_mod_list
-                )
-                safe_save_file(gathered_state_dict, os.path.join(checkpoint_dir, file_name), metadata={"format": "pt"})
-                clean_rank_files(folder_prefix=folder_prefix, file_name=file_name)
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()
-        clean_rank_files(folder_prefix=folder_prefix)
+            if format == SaveLoadFormat.VLLM:
+                gathered_state_dict = update_to_vllm_compatible(model, gathered_state_dict)
+            safe_save_file(gathered_state_dict, os.path.join(checkpoint_dir, file_name), metadata={"format": "pt"})
+            clean_rank_files(folder_prefix=folder_prefix, file_name=file_name)
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+    clean_rank_files(folder_prefix=folder_prefix)
 
 
 def save_for_single_device(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
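The previous vLLM branch all-gathered every tensor-parallel shard into one in-memory state dict before saving, which is presumably the OOM the commit title refers to. The new flow always takes the per-rank route: every rank writes its own rank-suffixed folder, a barrier guarantees all shards are on disk, and only rank 0 merges them file by file, converting each merged file to the vLLM layout when requested. A minimal sketch of that coordination pattern, assuming torch.distributed is initialized and using hypothetical save/gather helpers rather than the real save_rank_model / gather_state_dict:

import os
import torch.distributed as dist

def save_then_gather(model, checkpoint_dir, save_rank_shard, gather_file):
    # save_rank_shard and gather_file are hypothetical stand-ins for the
    # library helpers (save_rank_model / gather_state_dict in save_load.py).
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    shard_dir = f"{checkpoint_dir}_{rank}_{world_size}"  # rank-suffixed scratch folder
    save_rank_shard(model, shard_dir)                     # every rank writes only its own shard

    if dist.is_initialized():
        dist.barrier()                                    # wait until every rank has finished writing

    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        reference_dir = f"{checkpoint_dir}_0_{world_size}"
        for file_name in sorted(os.listdir(reference_dir)):
            merged = gather_file(checkpoint_dir, file_name, world_size)  # one file at a time
            # ... write `merged` to checkpoint_dir, then delete the per-rank copies of this file ...

    if dist.is_initialized():
        dist.barrier()                                    # keep ranks in sync before final cleanup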
@@ -558,6 +502,9 @@ def shard_state_dict(state_dict, return_all_rank=False, src_world_size=None):
     def get_rank_state_dict(state_dict, local_rank, world_size):
         rank_state_dict = {}
         for k, v in state_dict.items():
+            if k.endswith(f"_{world_size}") and not k.endswith(f"_{local_rank}_{world_size}"):
+                # only collect current rank state_dict and common state_dict (e.g., embedding)
+                continue
             new_k = remove_rank_suffix(k, local_rank, world_size)
             rank_state_dict[new_k] = v
         return rank_state_dict
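The added guard means each rank's view is built from only its own `_{local_rank}_{world_size}` entries plus un-suffixed common tensors (e.g. embeddings), instead of carrying every other rank's shards. A small self-contained illustration of that filtering, with made-up key names:

def rank_view(state_dict, local_rank, world_size):
    # Mirrors the filtering logic added to get_rank_state_dict, for illustration only.
    out = {}
    for k, v in state_dict.items():
        if k.endswith(f"_{world_size}") and not k.endswith(f"_{local_rank}_{world_size}"):
            continue  # shard belonging to another rank
        out[k.removesuffix(f"_{local_rank}_{world_size}")] = v
    return out

full = {
    "lm_head.weight_0_2": "rank0 shard",
    "lm_head.weight_1_2": "rank1 shard",
    "model.embed_tokens.weight": "common",
}
assert rank_view(full, local_rank=1, world_size=2) == {
    "lm_head.weight": "rank1 shard",
    "model.embed_tokens.weight": "common",
}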
@@ -621,7 +568,11 @@ def get_inc_fp8config(model, from_neuralmagic=False, from_neuralmagic_with_kv=Fa
         allowlist = {"types": ["Linear", "LinearLayer", "LinearAllreduce"], "names": []}
         qconfig = FP8Config(mode="LOAD", allowlist=allowlist, blocklist=blocklist, scale_format="CONST")
     else:
-        qconfig = FP8Config.from_dict(model.config.quantization_config)
+        if hasattr(model, "qconfig") and model.qconfig is not None:
+            configs_mapping = model.qconfig
+            qconfig = configs_mapping[next(iter(configs_mapping))]
+        else:
+            qconfig = FP8Config.from_dict(model.config.quantization_config)
     return qconfig
 
 
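The fallback order is new: a qconfig mapping still attached to the quantized model is preferred, and only otherwise is the config rebuilt from model.config.quantization_config as before. `next(iter(configs_mapping))` simply picks the mapping's first key; a tiny illustration with a made-up mapping (the real keys and config objects in INC differ):

# Hypothetical stand-in for the configs_mapping carried on a quantized model.
configs_mapping = {
    ("model.lm_head", "Linear"): {"mode": "QUANTIZE", "scale_format": "CONST"},
    ("model.layers.0.mlp.down_proj", "Linear"): {"mode": "QUANTIZE", "scale_format": "CONST"},
}
first_key = next(iter(configs_mapping))    # dicts preserve insertion order
qconfig = configs_mapping[first_key]
print(first_key, qconfig["scale_format"])  # ('model.lm_head', 'Linear') CONST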
@@ -678,7 +629,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
             raise EnvironmentError(
                 f"Original world_size: {src_world_size} must be divisible by target world_size: {world_size}."
             )
-        rank_state_dict = get_new_rank_state_dict(rank_state_dict, world_size, model=model)
+        rank_state_dict = get_new_rank_state_dict(rank_state_dict, model, world_size)
     else:
         rank_state_dict = gathered_state_dict
     model.load_state_dict(rank_state_dict, assign=True, strict=False)
@@ -816,8 +767,16 @@ def load_scale_params(model, new_scale_params):
             param.data = new_scale
 
 
-def get_new_rank_state_dict(all_rank_state_dict, world_size=world_size, local_rank=local_rank, model=None):
-    """Get new rank state_dict for world_size."""
+def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
+    """Get new rank state_dict for world_size.
+
+    Args:
+        all_rank_state_dict (dict): {0: state_dict, 1: state_dict, ...} for all ranks.
+        model (torch.nn.Module): A quantized model with empty weights in load mode.
+        world_size (int, optional): Target world size. Defaults to world_size.
+        local_rank (int, optional): Current local rank. Defaults to local_rank.
+    """
+
     def dq_q_weight(weight, previous_scale, target_scale):
         """dequantize and quantize weight with different scales."""
         cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, ScaleFormat.CONST)
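For context, the nested dq_q_weight helper (only partially visible in this hunk) dequantizes an FP8 weight with the scale it was saved under and requantizes it with the target scale, which is needed when shards saved for one world size are re-cut for another. A plain-PyTorch sketch of that round trip, assuming the usual real = quantized * scale convention; the actual code goes through the fused CAST_TO_FP8 quantized-func wrapper on HPU:

import torch

def dq_q_weight_sketch(q_weight, previous_scale, target_scale):
    # Hedged stand-in for dq_q_weight: dequantize with the old scale,
    # requantize with the new one (assumes real = quantized * scale).
    real = q_weight.to(torch.float32) * previous_scale
    return (real / target_scale).to(torch.float8_e4m3fn)

q = torch.randn(4, 4).to(torch.float8_e4m3fn)   # pretend this is a saved FP8 shard
requantized = dq_q_weight_sketch(q, previous_scale=0.5, target_scale=0.25)
print(requantized.dtype)                        # torch.float8_e4m3fn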
@@ -936,3 +895,34 @@ def update_corresponding_weight(previous_state_dict, cur_state_dict, target_weig
                 tp_dim = 0 if params_dict[k].shape[0] != new_rank_state_dict[k].shape[0] else 1
                 new_rank_state_dict[k] = torch.cat([new_rank_state_dict[k], v], dim=tp_dim)
     return new_rank_state_dict
+
+
+def update_to_vllm_compatible(model, gathered_state_dict):
+    """update gathered_state_dict to vllm compatible format.
+
+    Args:
+        model (model): the quantized model.
+        gathered_state_dict (dict): state_dict for all ranks.
+    Returns:
+        gathered_state_dict (dict): vllm compatible state_dict.
+    """
+    # for example, tp = 2, make
+    # 'model.layers.21.mlp.down_proj.weight_0_2', 'model.layers.21.mlp.down_proj.weight_1_2'
+    # --> 'model.layers.21.mlp.down_proj.weight'
+    import transformers
+    from accelerate import init_empty_weights
+    from neural_compressor.torch.algorithms.fp8_quant import prep_model
+
+    with init_empty_weights(include_buffers=False):
+        reference_model = transformers.AutoModelForCausalLM.from_config(model.config)
+    # replace modules with patched modules
+    inc_config = get_inc_fp8config(model)
+    inc_config.mode = "LOAD"
+    inc_config.save_temp_json_file()
+    prep_model(reference_model, inc_config.json_file)
+    # gather weights into 1 rank
+    rank_state_dict = shard_state_dict(gathered_state_dict, return_all_rank=True, src_world_size=world_size)
+    rank_state_dict = get_new_rank_state_dict(rank_state_dict, reference_model, world_size=1)
+    # rename param names
+    gathered_state_dict = convert_weight_to_vllm_compatible(state_dict=rank_state_dict)
+    return gathered_state_dict
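A large part of the memory saving is that the vLLM-layout reference model is built on the meta device with accelerate's init_empty_weights, so it contributes module structure, parameter names, and shapes but no real storage while serving as the template for re-assembling the gathered tensors. A minimal sketch of that trick (the model name is only an example):

import transformers
from accelerate import init_empty_weights

config = transformers.AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model
with init_empty_weights(include_buffers=False):
    reference_model = transformers.AutoModelForCausalLM.from_config(config)

# Parameters live on the meta device: names and shapes exist, real storage does not.
p = next(reference_model.parameters())
print(p.device, tuple(p.shape))   # meta (...)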

test/3x/torch/quantization/fp8_quant/test_save_load.py

Lines changed: 3 additions & 1 deletion
@@ -33,7 +33,7 @@ def compare_parameters_buffers(model1, model2, atol=1e-8):
     for k, v in dict1.items():
         assert k in dict2, "k not in dict2"
         assert v.dtype == dict2[k].dtype, f"dtype of {k} is differnt.\n{v.dtype}\n{dict2[k].dtype}"
-        assert torch.allclose(v, dict2[k], atol=atol), f"{k} is differnt in model1 and model2.\n" + f"{v}\n" + f"{dict2[k]}\n"
+        assert torch.allclose(v.float(), dict2[k].float(), atol=atol), f"{k} is differnt in model1 and model2.\n" + f"{v}\n" + f"{dict2[k]}\n"
 
 
 @torch.no_grad()
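Casting both sides to float32 before torch.allclose lets the same check cover low-precision checkpoints uniformly; elementwise comparison support for the FP8 dtypes is limited in PyTorch, so comparing in full precision is the safer route and keeps the fixed atol meaningful. A small illustration with bfloat16 tensors:

import torch

a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
b = a.clone()

# Compare in full precision; this also sidesteps dtype-support gaps for FP8 tensors.
assert torch.allclose(a.float(), b.float(), atol=1e-8)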
@@ -72,6 +72,8 @@ def test_save_vllm_compatible_model():
     generation_config.save_pretrained("saved_results_qwen")
     tokenizer = transformers.AutoTokenizer.from_pretrained(name)
     tokenizer.save_pretrained("saved_results_qwen")
+    shutil.rmtree("saved_results_qwen", ignore_errors=True)
+    shutil.rmtree("nc_workspace", ignore_errors=True)
 
 
 @pytest.mark.skip(reason="[SW-226589] Skip this test since the model was updated")
 def test_load_model_provided_by_neuralmagic():
