Commit 83a74cc
[SW-230951] Save measurements according to samples counter (#251)

* Added a post-forward hook to dump measurements according to a samples counter
* Added support for the samples counter in the config
* Removed a function in RowParallelLinear as it is removed from the vllm upstream code
* Currently only the blocking method is operational; async methods will be completed in a future commit
* Fixed CR comments
* Removed unused files
* Added a resolve_input method; it can't be defined in vllm due to upstream considerations, so it is copied here
* Fixed logging according to CR
* Fixed resolve_input and moved the hook function
1 parent ffd26cb commit 83a74cc

File tree

11 files changed: +281 -97 lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 3 additions & 1 deletion

@@ -74,7 +74,7 @@ def load_npz(fname):
     return d["arr_0"].item()


-def save_file(model, d, source_format, fname, mode):
+def save_file(model, d, source_format, fname, mode, num_samples=0):
     from .._quant_common.quant_config import get_hqt_config
     config = get_hqt_config(model)
     logger.debug("Saving %s file: %s", mode, fname)
@@ -87,6 +87,8 @@ def save_file(model, d, source_format, fname, mode):
         "Mode": mode,
         "Nodes": dc,
     }
+    if num_samples > 0:
+        df = {"NumSamples": num_samples, **df}
     try:
         file_functions[ext]['save'](df, fname)
     except:
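
The new num_samples argument only prepends a "NumSamples" field to the header dictionary that save_file already serializes. A minimal standalone sketch of that merge (plain Python; the sample values are illustrative, not from the commit):

    df = {"Mode": "DynamicRange", "Nodes": {"layer0": {}}}
    num_samples = 128

    if num_samples > 0:
        # dicts preserve insertion order, so "NumSamples" becomes the first header key
        df = {"NumSamples": num_samples, **df}

    print(list(df))  # ['NumSamples', 'Mode', 'Nodes']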

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 9 additions & 93 deletions

@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
-import os
-
 import numpy as np
 import torch

 from abc import abstractmethod

 from .._quant_common.quant_config import MeasureExclude, QuantMode, get_hqt_config, set_hqt_config
+from .save_measure import gmod_list
 from .scale_methods.scale_method_config import ScaleMethodString
 from ..utils.logger import logger
 from .common import load_file, save_file, ShapeList
@@ -33,10 +31,9 @@
     IMOD_DICT,
 )
 from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed
-cur_accelerator = auto_detect_accelerator()


-gmod_list = []
+cur_accelerator = auto_detect_accelerator()


 def patch_module_measure(mod, mconfig, mod_dict):
@@ -115,6 +112,12 @@ def prepare_model(model, mod_list=None):
     generate_model_info(model)
     register_patched_measure_modules(model, mod_list, observer_class, d_shapes)

+def setup_calibration_counter(model, config):
+    # used for automatically dumping measurements
+    calibration_sample_interval = int(config["calibration_sample_interval"])
+    if calibration_sample_interval > 0:
+        from .save_measure import add_calibration_samples_counter
+        add_calibration_samples_counter(model, calibration_sample_interval)

 def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=None):
     """Replace the submodules of the model that appear in mod_list with a patched submodule that uses the given observer_class
@@ -129,6 +132,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=None):
     """
     top_level_config = get_hqt_config(model)
     config = top_level_config.cfg
+    setup_calibration_counter(model, config)
     skip_outputs_measurements = config["measure_exclude"] & (MeasureExclude.OUTPUT | MeasureExclude.ALL)
     patched_types = set()
     non_patched_types = set()
@@ -187,94 +191,6 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=None):
     cur_accelerator.synchronize()


-def is_measure_done(mod_extra_config):
-    # check if measurements were collected by observer
-    for obs in ([] if mod_extra_config.inputs is None else mod_extra_config.inputs) + (
-        [] if mod_extra_config.outputs is None else mod_extra_config.outputs
-    ):
-        if obs.is_used():
-            return True
-    return False
-
-
-def get_mod_extra_config_dict(model):
-    mcd = {}
-    for name, mod in model.named_modules():
-        if hasattr(mod, "_mod_extra_config") and mod._mod_extra_config:
-            if is_measure_done(mod._mod_extra_config):
-                name = name.replace("_orig_mod.", "")  # remove _orig_mod part added by dynamo mechanism
-                mcd[name] = mod._mod_extra_config
-            else:
-                logger.debug(
-                    "Layer '%s' has no measurements therefore it can't be quantized during quantization.",
-                    name,
-                )
-    return mcd
-
-
-def measure_control_to_state_dict(mcd):
-    sd = {}
-    sdl = {}
-    for mname in mcd:
-        sd[mname] = dict()
-        sdl[mname] = dict()
-        sd[mname]["inputs"] = [
-            mcd[mname].inputs[i].state.detach().cpu().float().numpy()
-            for i in range(len(mcd[mname].inputs))
-            if mcd[mname].inputs[i].state is not None
-        ]
-        sdl[mname]["inputs"] = [
-            mcd[mname].inputs[i].state.detach().cpu().float().numpy().tolist()
-            for i in range(len(mcd[mname].inputs))
-            if mcd[mname].inputs[i].state is not None
-        ]
-        if mcd[mname].outputs:
-            sd[mname]["outputs"] = [
-                mcd[mname].outputs[i].state.detach().cpu().float().numpy()
-                for i in range(len(mcd[mname].outputs))
-                if mcd[mname].outputs[i].state is not None
-            ]
-            sdl[mname]["outputs"] = [
-                mcd[mname].outputs[i].state.detach().cpu().float().numpy().tolist()
-                for i in range(len(mcd[mname].outputs))
-                if mcd[mname].outputs[i].state is not None
-            ]
-        if len(mcd[mname].params) > 0:
-            sd[mname]["params"] = dict()
-            sdl[mname]["params"] = dict()
-            for param_name in mcd[mname].params:
-                if mcd[mname].params[param_name].state is not None:
-                    sd[mname]["params"][param_name] = mcd[mname].params[param_name].state.detach().cpu().float().numpy()
-                    sdl[mname]["params"][param_name] = (
-                        mcd[mname].params[param_name].state.detach().cpu().float().numpy().tolist()
-                    )
-    return sd, sdl
-
-
-def save_measurements(model, fname=None):
-    config = get_hqt_config(model).cfg
-    if config["mode"] in [QuantMode.MEASURE, QuantMode.SHAPE]:
-        if fname is None:
-            if ("measure_file" in config) and (config["measure_file"] is not None):
-                fname_base = config["measure_file"]
-                measure_type = "DynamicRange"
-            elif ("shape_file" in config) and (config["shape_file"] is not None) and (config["observer"] == "shape"):
-                fname_base = config["shape_file"]
-                measure_type = "Shape"
-            fname_np = fname_base + ".npz"
-            fname_list = fname_base + ".json"
-        else:
-            logger.warning("'fname' is not None - Measurements/Shapes will not be saved")
-            return
-        mcd = get_mod_extra_config_dict(model)
-        sd, sdl = measure_control_to_state_dict(mcd)
-
-        logger.info("Dumping measurements")
-        save_file(model, sd, np.ndarray, fname_np, measure_type)
-        save_file(model, sdl, list, fname_list, measure_type)
-        save_json(gmod_list, fname_base + "_mod_list.json")
-
-
 def load_measurements(model, fname):
     config = get_hqt_config(model).cfg
     source_fname = fname if fname is not None else config["measure_file"]

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .save_files import save_measurements, gmod_list
+from .hook_logic import add_calibration_samples_counter

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+import asyncio
+import os
+from threading import Thread
+
+from neural_compressor.torch.algorithms.fp8_quant.utils.logger import logger
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import get_hqt_config
+from .save_files import (
+    create_files_names, measure_control_to_state_dict, save_measurements_files, save_measurements,
+    get_mod_extra_config_dict, gmod_list)
+
+
+def dump_direct_call(model):
+    save_measurements(model)
+
+def dump_threading(model):
+    t = Thread(target=save_measurements, args=(model,), daemon=True)
+    t.start()
+
+def dump_async_call(model):
+    asyncio.run(dump_async_call_inner(model))
+
+async def dump_async_call_inner(model):
+    mcd = get_mod_extra_config_dict(model)
+    await save_measurements_async_wrapper(model, mcd)
+
+async def save_measurements_async_wrapper(model, mcd):
+    config = get_hqt_config(model).cfg
+    fname_base, fname_np, fname_list, measure_type = create_files_names(config)
+    sd, sdl = measure_control_to_state_dict(mcd)
+    save_measurements_files(model, sd, sdl, gmod_list, fname_np, fname_list, fname_base, measure_type)
+
+
+def dump_shelv(model):
+    pass
+
+_measurement_dump_method = os.getenv("MEASUREMENT_DUMP_METHOD", "1")
+_measurement_dump_method_dict = {
+    "1": dump_direct_call,
+    # below methods shouldn't be currently used as they are not fully completed
+    "2": dump_threading,
+    "3": dump_async_call,
+    "5": dump_shelv
+}
+
+
+def _increment_calibration_samples_counter(model, *args):  # post hook function
+    model.calibration_samples_counter += 1
+    if model.calibration_samples_counter % model.calibration_sample_interval == 0:
+        logger.debug("Reached sampling interval limit: %d, total samples: %d, dumping measurements.",
+                     model.calibration_sample_interval, model.calibration_samples_counter)
+        _measurement_dump_method_dict[_measurement_dump_method](model)
+        logger.debug("finished dumping measurements.")
+
+def add_calibration_samples_counter(model_to_calibrate, calibration_sample_interval):
+    """
+    Adds a forward post-hook to the model that counts the number of calibration samples processed.
+    When the maximum number of samples is reached, it saves the measurements.
+    """
+    model_to_calibrate.calibration_samples_counter = 0
+    model_to_calibrate.calibration_sample_interval = calibration_sample_interval
+    model_to_calibrate.register_forward_hook(_increment_calibration_samples_counter)
+    logger.info("Calibration samples interval added to the model - %d.", calibration_sample_interval)

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+import sys
+
+from ..measure import save_measurements
+
+if __name__ == "__main__":
+    model = sys.argv[0]
+    save_measurements(model)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+from neural_compressor.torch.algorithms.fp8_quant.utils.logger import logger
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import get_hqt_config, QuantMode
+from neural_compressor.torch.algorithms.fp8_quant._core.common import save_file, save_json
+
+
+def is_measure_done(mod_extra_config):
+    # check if measurements were collected by observer
+    for obs in ([] if mod_extra_config.inputs is None else mod_extra_config.inputs) + (
+        [] if mod_extra_config.outputs is None else mod_extra_config.outputs
+    ):
+        if obs.is_used():
+            return True
+    return False
+
+def get_mod_extra_config_dict(model):
+    mcd = {}
+    for name, mod in model.named_modules():
+        if hasattr(mod, "_mod_extra_config") and mod._mod_extra_config:
+            if is_measure_done(mod._mod_extra_config):
+                name = name.replace("_orig_mod.", "")  # remove _orig_mod part added by dynamo mechanism
+                mcd[name] = mod._mod_extra_config
+            else:
+                logger.debug(
+                    "Layer '%s' has no measurements therefore it can't be quantized during quantization.",
+                    name,
+                )
+    return mcd
+
+def measure_control_to_state_dict(mcd):
+    sd = {}
+    sdl = {}
+    for mname in mcd:
+        sd[mname] = dict()
+        sdl[mname] = dict()
+        sd[mname]["inputs"] = [
+            mcd[mname].inputs[i].state.detach().cpu().float().numpy()
+            for i in range(len(mcd[mname].inputs))
+            if mcd[mname].inputs[i].state is not None
+        ]
+        sdl[mname]["inputs"] = [
+            mcd[mname].inputs[i].state.detach().cpu().float().numpy().tolist()
+            for i in range(len(mcd[mname].inputs))
+            if mcd[mname].inputs[i].state is not None
+        ]
+        if mcd[mname].outputs:
+            sd[mname]["outputs"] = [
+                mcd[mname].outputs[i].state.detach().cpu().float().numpy()
+                for i in range(len(mcd[mname].outputs))
+                if mcd[mname].outputs[i].state is not None
+            ]
+            sdl[mname]["outputs"] = [
+                mcd[mname].outputs[i].state.detach().cpu().float().numpy().tolist()
+                for i in range(len(mcd[mname].outputs))
+                if mcd[mname].outputs[i].state is not None
+            ]
+        if len(mcd[mname].params) > 0:
+            sd[mname]["params"] = dict()
+            sdl[mname]["params"] = dict()
+            for param_name in mcd[mname].params:
+                if mcd[mname].params[param_name].state is not None:
+                    sd[mname]["params"][param_name] = mcd[mname].params[param_name].state.detach().cpu().float().numpy()
+                    sdl[mname]["params"][param_name] = (
+                        mcd[mname].params[param_name].state.detach().cpu().float().numpy().tolist()
+                    )
+    return sd, sdl
+
+def create_files_names(config, fname=None):
+    if fname is None:
+        if ("measure_file" in config) and (config["measure_file"] is not None):
+            fname_base = config["measure_file"]
+            measure_type = "DynamicRange"
+        elif ("shape_file" in config) and (config["shape_file"] is not None) and (config["observer"] == "shape"):
+            fname_base = config["shape_file"]
+            measure_type = "Shape"
+        fname_np = fname_base + ".npz"
+        fname_list = fname_base + ".json"
+        return fname_base, fname_np, fname_list, measure_type
+    else:
+        logger.warning("'fname' is not None - Measurements/Shapes will not be saved")
+        return
+
+def save_measurements_files(model, state_dict, state_list, gmod_list, fname_np, fname_list, fname_base, measure_type,
+                            num_samples=0):
+    import numpy as np
+    logger.info("Dumping measurements")
+    save_file(model, state_dict, np.ndarray, fname_np, measure_type, num_samples)
+    save_file(model, state_list, list, fname_list, measure_type, num_samples)
+    save_json(gmod_list, fname_base + "_mod_list.json")
+    return
+
+
+gmod_list = []  # global list extended with patched modules in measure.prepare_model
+
+
+def save_measurements(model, fname=None):
+    config = get_hqt_config(model).cfg
+    if config["mode"] in [QuantMode.MEASURE, QuantMode.SHAPE]:
+        fname_base, fname_np, fname_list, measure_type = create_files_names(config, fname)
+        mcd = get_mod_extra_config_dict(model)
+        sd, sdl = measure_control_to_state_dict(mcd)
+        num_samples = model.calibration_samples_counter if hasattr(model, "calibration_samples_counter") else 0
+        save_measurements_files(model, sd, sdl, gmod_list, fname_np, fname_list, fname_base, measure_type, num_samples)
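
Since save_measurements now forwards the counter into save_file, a measurement file dumped mid-calibration carries the sample count in its header. A hedged sketch of reading it back, assuming the .json output and a measure_file base name of "measurements" (header keys beyond "Mode" and "Nodes" depend on the existing save_file implementation):

    import json

    # "measurements.json" stands for config["measure_file"] + ".json"
    with open("measurements.json") as f:
        header = json.load(f)

    print(header.get("NumSamples"))  # present only when the counter was active, e.g. 128
    print(header["Mode"])            # e.g. "DynamicRange"
    print(list(header["Nodes"]))     # names of measured modules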

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 13 additions & 1 deletion

@@ -347,7 +347,6 @@ def post_all_reduce(self, input):

 class PatchedRowParallelLinear(PatchedLinearBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
-        kwargs["func_names"] = ("resolve_input", )
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         from .._core.vllm_functions import get_vllm_row_parallel_collective_func
         self.row_parallel_collective_func = get_vllm_row_parallel_collective_func()
@@ -377,6 +376,19 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         from torch import distributed as dist
         self.world_size = dist.get_world_size()

+    def resolve_input(self, input_):
+        """
+        This code is copied from the vllm RowParallelLinear forward method.
+        """
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+        return input_parallel
+
     def forward_qdq(self, input):
         # TODO: [SW-208441] Support all_reduce_fp8 in forward_qdq in PatchedRowParallelLinear
         resolved_input = self.resolve_input(input)
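
The copied resolve_input mirrors vllm's RowParallelLinear input handling: when the input is not already sharded, each tensor-parallel rank takes its own slice along the last dimension. A standalone sketch of that slicing in plain torch (rank and world size are illustrative; for evenly divisible sizes, vllm's split_tensor_along_last_dim behaves like torch.chunk):

    import torch

    def resolve_input_sketch(input_, input_is_parallel, tp_rank, tp_size):
        if input_is_parallel:
            return input_  # already sharded, use as-is
        # split along the last dim and keep this rank's shard
        shards = torch.chunk(input_, chunks=tp_size, dim=-1)
        return shards[tp_rank].contiguous()

    x = torch.randn(2, 8)
    print(resolve_input_sketch(x, False, tp_rank=1, tp_size=4).shape)  # torch.Size([2, 2])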

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 2 additions & 1 deletion

@@ -184,7 +184,8 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         "scale_format": ScaleFormat.SCALAR,
         "measure_on_hpu": True,  # Determines whether to measure model on hpu device.
         "row_parallel_linear_allreduce_quantization": False,  # Turn on/off fp8 allreduce optimization detailed in SW-207602
-        "dynamic_quantization": False  # Turn on/off fp8 dynamic quantization
+        "dynamic_quantization": False,  # Turn on/off fp8 dynamic quantization
+        "calibration_sample_interval": 0  # number of samples to process before dumping measurements, 0 means no automatic dumping
     }
     # go over all user-defined keys from json, handle various cases
     for keys in custom_config:
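
With the default in place, periodic dumping is opt-in: any positive calibration_sample_interval in the user's quantization config enables it. A hedged sketch of such a config as a Python dict (measure_file and mode are existing keys referenced elsewhere in this commit; the values shown are illustrative):

    # Hypothetical measurement-mode config enabling a dump every 100 samples.
    custom_config = {
        "mode": "MEASURE",                      # measurement mode, value illustrative
        "measure_file": "inc_output/measure",   # base name for the .npz/.json dumps
        "calibration_sample_interval": 100,     # new key; 0 (the default) disables auto-dumping
    }

The dump strategy is selected by the MEASUREMENT_DUMP_METHOD environment variable in the new hook module; per the commit message, only the blocking method ("1", direct call) is currently operational.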
