     SaveLoadFormat,
     get_enum_from_format,
     UNIT_MAPPING,
+    write_json_file,
 )
 
 
@@ -390,8 +391,13 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
         # Ensure those codes run on a single rank.
         configs_mapping = model.qconfig
         config_object = configs_mapping[next(iter(configs_mapping))]
-        update_model_config(model, format, config_object)
-        model.config.save_pretrained(checkpoint_dir)
+        config_object.mode = "LOAD"
+        config_object.world_size = world_size  # record world_size for loading
+        # Flux pipeline has FrozenDict as config
+        if not isinstance(model.config, dict):
+            update_model_config(model, format, config_object)
+            model.config.save_pretrained(checkpoint_dir)
+        write_json_file(os.path.join(checkpoint_dir, "quantization_config.json"), config_object.to_dict())
 
         if hasattr(model, "generation_config") and model.generation_config is not None:
             model.generation_config.save_pretrained(checkpoint_dir)
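
The new write_json_file call persists the quantization settings outside model.config, so pipelines whose config is a diffusers FrozenDict (and therefore never go through save_pretrained here) still leave a recoverable record next to the checkpoint. A minimal sketch, assuming only the call shape visible in the hunk above; the real write_json_file helper is imported from the library and may differ:

    import json
    import os

    def write_json_file_sketch(file_path, data):
        # assumed stand-in for the imported write_json_file helper
        with open(file_path, "w") as f:
            json.dump(data, f, indent=2)

    # mirrors the call in the hunk above:
    # write_json_file_sketch(os.path.join(checkpoint_dir, "quantization_config.json"), config_object.to_dict())
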
@@ -405,16 +411,31 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
405411 """Initialize BF16 model with meta tensor."""
406412 import transformers
407413 from accelerate import init_empty_weights
408- config = transformers .AutoConfig .from_pretrained (model_name_or_path , ** kwargs )
414+
415+ # Handling model objects not in AutoModelForCausalLM
416+ model = kwargs .get ("original_model" , None )
417+ # Handle Flux pipeline without AutoConfig
418+ try :
419+ config = transformers .AutoConfig .from_pretrained (model_name_or_path , ** kwargs )
420+ quantization_config = config .quantization_config if hasattr (config , "quantization_config" ) else None
421+ hp_dtype = config .torch_dtype
422+ except :
423+ config , hp_dtype = model .config , torch .bfloat16
424+ quantization_config = kwargs .get ("quantization_config" , None )
425+ setattr (model .config , "quantization_config" , quantization_config )
426+
427+ if quantization_config is not None and "hp_dtype" in quantization_config :
428+ hp_dtype = HpDtype [quantization_config ["hp_dtype" ].upper ()].value
429+
409430 # fp8 model provided by neuralmagic.
410431 if (
411- "quant_method" in config . quantization_config
412- and config . quantization_config ["quant_method" ] in ["fp8" , "compressed-tensors" ]
432+ "quant_method" in quantization_config
433+ and quantization_config ["quant_method" ] in ["fp8" , "compressed-tensors" ]
413434 ):
414435 from_neuralmagic = True
415436 if (
416- "kv_cache_scheme" in config . quantization_config
417- and config . quantization_config ["kv_cache_scheme" ] is not None
437+ "kv_cache_scheme" in quantization_config
438+ and quantization_config ["kv_cache_scheme" ] is not None
418439 ):
419440 from_neuralmagic_with_kv = True
420441 else :
@@ -431,16 +452,13 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
         else:
             raise ValueError("Please install optimum-habana to load fp8 kv cache model.")
 
-    from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
-
-    hp_dtype = config.torch_dtype
-    if hasattr(config, "quantization_config") and "hp_dtype" in config.quantization_config:
-        hp_dtype = HpDtype[config.quantization_config["hp_dtype"].upper()].value
+    if model is None:
+        with init_empty_weights(include_buffers=False):
+            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
     if world_size > 1:
         import deepspeed
+        from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
 
-        with init_empty_weights(include_buffers=False):
-            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
         # TODO: [SW-199728] [DeepSpeed] Buffers initialized by model are not correct after tensor parallel
         # get_non_persistent_buffers and load_non_persistent_buffers are workarounds of [SW-199728]
         non_persistent_buffers = get_non_persistent_buffers(model)
@@ -451,16 +469,13 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
         model = model.module
         load_non_persistent_buffers(model, non_persistent_buffers)
-    else:
-        with init_empty_weights(include_buffers=False):
-            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
     model.to(hp_dtype)
 
     try:
         generation_config = transformers.GenerationConfig.from_pretrained(model_name_or_path, **kwargs)
         model.generation_config = generation_config
     except:  # Since model.generation_config is optional, relaxed exceptions can handle more situations.
-        logger.warning("model.generation_config is not loaded correctly.")
+        logger.warning("model.generation_config may not be loaded correctly.")
     return model, from_neuralmagic, from_neuralmagic_with_kv
 
 
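For checkpoints that transformers.AutoConfig cannot resolve (for example a Flux pipeline component), the caller now supplies the raw module and its quantization config through kwargs, and the except branch above falls back to them. A hedged caller-side sketch: the kwarg names "original_model" and "quantization_config" come from the hunk above, while flux_pipe and the small helper are illustrative assumptions, not part of this patch:

    import json
    import os

    def read_quantization_config(checkpoint_dir):
        # reads back the quantization_config.json written by save()
        with open(os.path.join(checkpoint_dir, "quantization_config.json")) as f:
            return json.load(f)

    # model, from_neuralmagic, from_neuralmagic_with_kv = load_empty_raw_model(
    #     "saved_results",
    #     original_model=flux_pipe.transformer,                         # object AutoConfig cannot build
    #     quantization_config=read_quantization_config("saved_results"),
    # )
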
@@ -635,7 +650,8 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     model.load_state_dict(rank_state_dict, assign=True, strict=False)
     load_scale_params(model, rank_state_dict)  # ensure per-channel scale is loaded correctly
     clear_quantized_func_wrapper_factory()
-    model.tie_weights()
+    if hasattr(model, "tie_weights"):
+        model.tie_weights()
     model = model.to(cur_accelerator.name())
     model = model.eval()
     cur_accelerator.synchronize()
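
The hasattr guard matters because tie_weights is a transformers.PreTrainedModel method; plain torch.nn.Module subclasses such as diffusers transformer blocks do not define it, so the previous unguarded call would raise AttributeError. A small self-contained illustration of the guard:

    import torch

    module = torch.nn.Linear(4, 4)       # stand-in for a model without tie_weights()
    if hasattr(module, "tie_weights"):   # False here, so the call is skipped safely
        module.tie_weights()
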
@@ -745,8 +761,6 @@ def update_model_config(model, format, config_object):
         quantization_config = convert_config_to_vllm_compatible(config_object)
         model.config.quantization_config = quantization_config
     else:
-        config_object.mode = "LOAD"
-        config_object.world_size = world_size  # record world_size for loading
         model.config.quantization_config = config_object
 
 
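Because mode and world_size are now set in save() itself rather than in update_model_config, they reach quantization_config.json even when update_model_config is skipped for FrozenDict configs. An assumed, partial illustration of the resulting file, expressed as a Python dict; only keys visible in this diff are shown, the values and any other fields are hypothetical:

    # hypothetical excerpt of saved_results/quantization_config.json
    example_quantization_config = {
        "mode": "LOAD",      # set in save() so the checkpoint is loaded in load mode
        "world_size": 1,     # recorded so load() can match the saving topology
        "hp_dtype": "bf16",  # consumed via HpDtype[...].value in load_empty_raw_model (value assumed)
    }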