Skip to content

Commit 56e8882

Browse files
ulivnexinhe3
authored and committed
[SW-199936] Remove collective_func usage as preparation for vLLM upstream (#249)
self.collective_func is not allowed for upstream; therefore we explicitly use vLLM's collective functions, while protecting against import errors and circular imports
1 parent 5812c75 commit 56e8882

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
# Optional import of vLLM's tensor-parallel collective functions.
# vLLM may not be installed in every environment, so the names fall back
# to None here and the accessor functions in this module fail loudly only
# when a collective function is actually requested.
try:
    from vllm.distributed import tensor_model_parallel_all_gather
    from vllm.distributed import tensor_model_parallel_all_reduce
except ImportError:
    # vLLM (or its distributed submodule) is unavailable; record that fact.
    tensor_model_parallel_all_gather = None
    tensor_model_parallel_all_reduce = None
25+
26+
def get_vllm_row_parallel_collective_func():
    """Return vLLM's ``tensor_model_parallel_all_gather`` collective.

    Used by row-parallel linear layers as a replacement for the removed
    ``self.collective_func`` attribute.

    Returns:
        The ``tensor_model_parallel_all_gather`` callable imported above.

    Raises:
        ImportError: if vLLM could not be imported at module load time.
    """
    # Use an explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently return None here.
    if tensor_model_parallel_all_gather is None:
        raise ImportError(
            "Couldn't import vllm function tensor_model_parallel_all_gather")
    return tensor_model_parallel_all_gather
29+
30+
def get_vllm_column_parallel_collective_func():
    """Return vLLM's ``tensor_model_parallel_all_reduce`` collective.

    Used by column-parallel linear layers as a replacement for the removed
    ``self.collective_func`` attribute.

    Returns:
        The ``tensor_model_parallel_all_reduce`` callable imported above.

    Raises:
        ImportError: if vLLM could not be imported at module load time.
    """
    # Use an explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently return None here.
    if tensor_model_parallel_all_reduce is None:
        raise ImportError(
            "Couldn't import vllm function tensor_model_parallel_all_reduce")
    return tensor_model_parallel_all_reduce

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ class PatchedRowParallelLinear(PatchedLinearBase):
349349
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
350350
kwargs["func_names"] = ("resolve_input", )
351351
super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
352+
from .._core.vllm_functions import get_vllm_row_parallel_collective_func
353+
self.row_parallel_collective_func = get_vllm_row_parallel_collective_func()
352354
# TODO [SW-224403]: Enable dynamic quantization in row parallel allreduce
353355
allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"]
354356
if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE):
@@ -381,7 +383,7 @@ def forward_qdq(self, input):
381383
output = self.run_linear_qdq(resolved_input, None)
382384

383385
if self.reduce_results:
384-
output = self.collective_func(output)
386+
output = self.row_parallel_collective_func(output)
385387
return self.bias_add(output)
386388

387389
def lp_matmul_hp(self, input):
@@ -402,12 +404,12 @@ def forward_quant_reduce_in_lp(self, input):
402404
if input.shape[1] == 1:
403405
allreduce_output_hp = self.quant_all_reduce_sum(matmul_output_hp)
404406
else:
405-
allreduce_output_hp = self.collective_func(matmul_output_hp)
407+
allreduce_output_hp = self.row_parallel_collective_func(matmul_output_hp)
406408
return self.bias_add(allreduce_output_hp)
407409

408410
def forward_quant_reduce_in_hp(self, input):
409411
matmul_output_hp = self.lp_matmul_hp(input)
410-
all_reduce_output_hp = self.collective_func(matmul_output_hp)
412+
all_reduce_output_hp = self.row_parallel_collective_func(matmul_output_hp)
411413
return self.bias_add(all_reduce_output_hp)
412414

413415
def measure_input_and_matmul(self, input):
@@ -426,7 +428,7 @@ def forward_measure_reduce(self, input):
426428
output = self.measure_input_and_matmul(input)
427429
max_output = output.clone()
428430
dist.all_reduce(max_output, op=dist.ReduceOp.MAX)
429-
all_reduce_output = self.collective_func(output)
431+
all_reduce_output = self.row_parallel_collective_func(output)
430432
measure_output((max_output, all_reduce_output,), self._mod_extra_config.outputs)
431433
return self.bias_add(all_reduce_output)
432434

@@ -478,12 +480,14 @@ class PatchedColumnParallelLinear(PatchedLinearBase):
478480
def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
479481
super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
480482
self.init_linear(mod_extra_config)
483+
from .._core.vllm_functions import get_vllm_column_parallel_collective_func
484+
self.column_parallel_collective_func = get_vllm_column_parallel_collective_func()
481485

482486
def forward_qdq(self, input):
483487
output = self.run_linear_qdq(input, None)
484488
output, output_bias = self.add_bias(output)
485489
if self.gather_output:
486-
output = self.collective_func(output)
490+
output = self.column_parallel_collective_func(output)
487491
return output, output_bias
488492

489493
def forward_quant(self, input):
@@ -492,7 +496,7 @@ def forward_quant(self, input):
492496
dqoutput = self.dequant_output(output)
493497
dqoutput, dqoutput_bias = self.add_bias(dqoutput)
494498
if self.gather_output:
495-
dqoutput = self.collective_func(dqoutput)
499+
dqoutput = self.column_parallel_collective_func(dqoutput)
496500
return dqoutput, dqoutput_bias
497501

498502
def forward_measure(self, input):
@@ -501,7 +505,7 @@ def forward_measure(self, input):
501505
measure_output((output,), self._mod_extra_config.outputs)
502506
output, output_bias = self.add_bias(output)
503507
if self.gather_output:
504-
output = self.collective_func(output)
508+
output = self.column_parallel_collective_func(output)
505509
return output, output_bias
506510

507511
def add_bias(self, output):

0 commit comments

Comments
 (0)