
Commit 4677fdd

xin3he and xinhe3 authored
[SW-228570] support FP8 GaudiFluxPipeline save and load (#254)
* [SW-228570] support FP8 GaudiFluxPipeline save and load

Signed-off-by: Xin He <[email protected]>
Co-authored-by: Xin He <[email protected]>
1 parent 8b81146 commit 4677fdd
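
For orientation, here is a minimal sketch of the workflow this commit enables: saving an FP8-quantized Flux transformer in Hugging Face format and reloading it into the pipeline. The helper function, pipeline object `pipe`, and checkpoint path are illustrative assumptions; only the save/load entry points and their original_model/format/device arguments come from the files changed below.

# Hypothetical usage sketch; `pipe` is assumed to be an FP8-quantized GaudiFluxPipeline.
from neural_compressor.torch.algorithms.fp8_quant.save_load import save
from neural_compressor.torch.quantization.save_load_entry import load

def save_and_reload_flux_transformer(pipe, checkpoint_dir="flux_fp8"):
    """Round-trip the pipeline's FP8 transformer through disk (illustrative helper)."""
    # For Flux, model.config is a FrozenDict, so save() now writes
    # quantization_config.json instead of calling config.save_pretrained().
    save(pipe.transformer, checkpoint_dir=checkpoint_dir, format="huggingface")
    # The transformer is not an AutoModelForCausalLM, so the original object is
    # handed to the loader via `original_model` and repopulated from the checkpoint.
    pipe.transformer = load(
        checkpoint_dir,
        original_model=pipe.transformer,
        format="huggingface",
        device="hpu",
    )
    return pipe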

File tree

3 files changed: +86, -32 lines changed


neural_compressor/torch/algorithms/fp8_quant/save_load.py

Lines changed: 35 additions & 21 deletions
@@ -36,6 +36,7 @@
     SaveLoadFormat,
     get_enum_from_format,
     UNIT_MAPPING,
+    write_json_file,
 )
 
 
@@ -390,8 +391,13 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
         # Ensure those codes run on a single rank.
         configs_mapping = model.qconfig
         config_object = configs_mapping[next(iter(configs_mapping))]
-        update_model_config(model, format, config_object)
-        model.config.save_pretrained(checkpoint_dir)
+        config_object.mode = "LOAD"
+        config_object.world_size = world_size  # record world_size for loading
+        # Flux pipeline has FrozenDict as config
+        if not isinstance(model.config, dict):
+            update_model_config(model, format, config_object)
+            model.config.save_pretrained(checkpoint_dir)
+        write_json_file(os.path.join(checkpoint_dir, "quantization_config.json"), config_object.to_dict())
 
         if hasattr(model, "generation_config") and model.generation_config is not None:
             model.generation_config.save_pretrained(checkpoint_dir)
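
The `isinstance(model.config, dict)` guard works because diffusers keeps a pipeline component's config in a FrozenDict, which derives from dict, while a transformers PretrainedConfig does not. A small stand-alone illustration with stand-in classes so it runs without diffusers or transformers installed (class names and the dict contents are ours):

from collections import OrderedDict

class FrozenDictStandIn(OrderedDict):
    """Stand-in for diffusers' FrozenDict, which also derives from (Ordered)dict."""

class PretrainedConfigStandIn:
    """Stand-in for a transformers PretrainedConfig, which is a plain object."""

flux_like_config = FrozenDictStandIn(in_channels=64)
llm_like_config = PretrainedConfigStandIn()

print(isinstance(flux_like_config, dict))  # True  -> skip config.save_pretrained(), only write quantization_config.json
print(isinstance(llm_like_config, dict))   # False -> keep the original HF config path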
@@ -405,16 +411,31 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     """Initialize BF16 model with meta tensor."""
     import transformers
     from accelerate import init_empty_weights
-    config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
+
+    # Handling model objects not in AutoModelForCausalLM
+    model = kwargs.get("original_model", None)
+    # Handle Flux pipeline without AutoConfig
+    try:
+        config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
+        quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        hp_dtype = config.torch_dtype
+    except:
+        config, hp_dtype = model.config, torch.bfloat16
+        quantization_config = kwargs.get("quantization_config", None)
+        setattr(model.config, "quantization_config", quantization_config)
+
+    if quantization_config is not None and "hp_dtype" in quantization_config:
+        hp_dtype = HpDtype[quantization_config["hp_dtype"].upper()].value
+
     # fp8 model provided by neuralmagic.
     if (
-        "quant_method" in config.quantization_config
-        and config.quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
+        "quant_method" in quantization_config
+        and quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
     ):
         from_neuralmagic = True
         if (
-            "kv_cache_scheme" in config.quantization_config
-            and config.quantization_config["kv_cache_scheme"] is not None
+            "kv_cache_scheme" in quantization_config
+            and quantization_config["kv_cache_scheme"] is not None
         ):
             from_neuralmagic_with_kv = True
     else:
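
When AutoConfig.from_pretrained fails (as it does for a Flux transformer checkpoint), the new path falls back to the passed-in original_model's config with a bfloat16 default, then lets the saved quantization config override hp_dtype. A rough stand-alone mirror of that dtype resolution, using an illustrative mapping in place of the internal HpDtype enum:

import torch

# Illustrative stand-in for the HpDtype lookup used above.
_DTYPES = {"BF16": torch.bfloat16, "FP16": torch.float16, "FP32": torch.float32}

def resolve_hp_dtype(quantization_config, default=torch.bfloat16):
    """Return the high-precision dtype recorded in the quantization config, else the default."""
    if quantization_config is not None and "hp_dtype" in quantization_config:
        return _DTYPES[quantization_config["hp_dtype"].upper()]
    return default

print(resolve_hp_dtype({"hp_dtype": "bf16"}))  # torch.bfloat16, taken from the saved config
print(resolve_hp_dtype(None))                  # torch.bfloat16, Flux fallback default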
@@ -431,16 +452,13 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
         else:
             raise ValueError("Please install optimum-habana to load fp8 kv cache model.")
 
-    from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
-
-    hp_dtype = config.torch_dtype
-    if hasattr(config, "quantization_config") and "hp_dtype" in config.quantization_config:
-        hp_dtype = HpDtype[config.quantization_config["hp_dtype"].upper()].value
+    if model is None:
+        with init_empty_weights(include_buffers=False):
+            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
     if world_size > 1:
         import deepspeed
+        from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
 
-        with init_empty_weights(include_buffers=False):
-            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
         # TODO: [SW-199728] [DeepSpeed] Buffers initialized by model are not correct after tensor parallel
         # get_non_persistent_buffers and load_non_persistent_buffers are workarounds of [SW-199728]
         non_persistent_buffers = get_non_persistent_buffers(model)
@@ -451,16 +469,13 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
         model = model.module
         load_non_persistent_buffers(model, non_persistent_buffers)
-    else:
-        with init_empty_weights(include_buffers=False):
-            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
     model.to(hp_dtype)
 
     try:
         generation_config = transformers.GenerationConfig.from_pretrained(model_name_or_path, **kwargs)
         model.generation_config = generation_config
     except:  # Since model.generation_config is optional, relaxed exceptions can handle more situations.
-        logger.warning("model.generation_config is not loaded correctly.")
+        logger.warning("model.generation_config may not be loaded correctly.")
     return model, from_neuralmagic, from_neuralmagic_with_kv
 
 
@@ -635,7 +650,8 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
     model.load_state_dict(rank_state_dict, assign=True, strict=False)
     load_scale_params(model, rank_state_dict)  # ensure per-channel scale is loaded correctly
     clear_quantized_func_wrapper_factory()
-    model.tie_weights()
+    if hasattr(model, "tie_weights"):
+        model.tie_weights()
     model = model.to(cur_accelerator.name())
     model = model.eval()
     cur_accelerator.synchronize()
@@ -745,8 +761,6 @@ def update_model_config(model, format, config_object):
         quantization_config = convert_config_to_vllm_compatible(config_object)
         model.config.quantization_config = quantization_config
     else:
-        config_object.mode = "LOAD"
-        config_object.world_size = world_size  # record world_size for loading
         model.config.quantization_config = config_object
 

neural_compressor/torch/quantization/save_load_entry.py

Lines changed: 32 additions & 6 deletions
@@ -26,7 +26,7 @@
     RTNConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import SaveLoadFormat, get_enum_from_format
+from neural_compressor.torch.utils import SaveLoadFormat, get_enum_from_format, read_json_file
 
 config_name_mapping = {
     FP8_QUANT: FP8Config,
@@ -111,6 +111,8 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
         from neural_compressor.common.base_config import ConfigRegistry
 
         qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json")
+        if not os.path.exists(qconfig_file_path):
+            raise ValueError(f"qconfig.json file is necessary for the default format.")
         with open(qconfig_file_path, "r") as f:
             per_op_qconfig = json.load(f)
 
@@ -138,22 +140,46 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
         return qmodel.to(device)
     elif format == SaveLoadFormat.HUGGINGFACE:
         import transformers
-        config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
+
+        try:
+            config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)
+            quantization_config = config.quantization_config
+        except:
+            quantization_config_file = "quantization_config.json"
+            # for Flux pipeline
+            if os.path.exists(model_name_or_path):
+                # If the model_name_or_path is a local path, try to load the config from there
+                quantization_config_path = os.path.join(model_name_or_path, quantization_config_file)
+            else:
+                # If the model_name_or_path is a Hugging Face model ID, try to download the config
+                from huggingface_hub import hf_hub_download
+
+                quantization_config_path = hf_hub_download(
+                    repo_id=model_name_or_path,
+                    filename=quantization_config_file,
+                    revision=kwargs.get("revision", "main"),
+                )
+            quantization_config = read_json_file(quantization_config_path)
+            kwargs["quantization_config"] = quantization_config
+
+        if original_model is not None:
+            kwargs["original_model"] = original_model
         # use config to check which algorithm is used.
         if (
-            "fp8_config" in config.quantization_config or
+            "fp8_config" in quantization_config or
             # for FP8 LLMs for vLLM (https://huggingface.co/neuralmagic).
             (
-                "quant_method" in config.quantization_config and
-                config.quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
+                "quant_method" in quantization_config and
+                quantization_config["quant_method"] in ["fp8", "compressed-tensors"]
             )
         ):
             from neural_compressor.torch.algorithms import fp8_quant
+
             return fp8_quant.load(model_name_or_path, format=format, device=device, **kwargs)
         else:
             from neural_compressor.torch.algorithms import weight_only
 
             qmodel = weight_only.load(model_name_or_path, format=SaveLoadFormat.HUGGINGFACE, device=device, **kwargs)
             return qmodel.to(device)
     else:
-        assert False, "This code path should never be reached."
+        assert False, "Unexpected format: {} occurred during model loading".format(format)

neural_compressor/torch/utils/utility.py

Lines changed: 19 additions & 5 deletions
@@ -13,7 +13,8 @@
 # limitations under the License.
 """Intel Neural Compressor PyTorch utilities."""
 
-
+import os
+import json
 import enum
 import importlib
 from collections import UserDict
@@ -311,8 +312,6 @@ def get_processor_type_from_user_config(user_processor_type: Optional[Union[str,
 
 def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
     """Download hugging face model from hf hub."""
-    import os
-
     from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
     from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name
     from huggingface_hub.utils import EntryNotFoundError
@@ -356,8 +355,6 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
 
 def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
     """Load a empty model."""
-    import os
-
     from accelerate import init_empty_weights
     from transformers import AutoConfig, AutoModelForCausalLM
     from transformers.models.auto.auto_factory import _BaseAutoModelClass
@@ -741,3 +738,20 @@ def get_enum_from_format(format):
         return obj
     raise ValueError(
         f"Invalid format value ('{format}'). Enter one of [{[m.name for m in SaveLoadFormat]}]")
+
+
+def read_json_file(file_path):
+    """Read a JSON file and return its content."""
+    if not file_path or not os.path.exists(file_path):
+        raise FileNotFoundError(f"File {file_path} does not exist.")
+    with open(file_path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def write_json_file(file_path, data):
+    """Write data to a JSON file."""
+    if not file_path:
+        raise ValueError("File path cannot be empty.")
+    with open(file_path, "w", encoding="utf-8") as f:
+        json.dump(data, f, indent=4, ensure_ascii=False)
+    logger.info(f"Data written to {file_path} successfully.")
