@@ -349,6 +349,8 @@ class PatchedRowParallelLinear(PatchedLinearBase):
349349 def __init__ (self , mod , parent , mod_extra_config , * args , ** kwargs ):
350350 kwargs ["func_names" ] = ("resolve_input" , )
351351 super ().__init__ (mod , parent , mod_extra_config , * args , ** kwargs )
352+ from .._core .vllm_functions import get_vllm_row_parallel_collective_func
353+ self .row_parallel_collective_func = get_vllm_row_parallel_collective_func ()
352354 # TODO [SW-224403]: Enable dynamic quantization in row parallel allreduce
353355 allreduce_quantization_enable = get_hqt_config (mod ).cfg ["row_parallel_linear_allreduce_quantization" ]
354356 if self .quantization_mode in (QuantMode .MEASURE , QuantMode .SHAPE ):
@@ -381,7 +383,7 @@ def forward_qdq(self, input):
381383 output = self .run_linear_qdq (resolved_input , None )
382384
383385 if self .reduce_results :
384- output = self .collective_func (output )
386+ output = self .row_parallel_collective_func (output )
385387 return self .bias_add (output )
386388
387389 def lp_matmul_hp (self , input ):
@@ -402,12 +404,12 @@ def forward_quant_reduce_in_lp(self, input):
402404 if input .shape [1 ] == 1 :
403405 allreduce_output_hp = self .quant_all_reduce_sum (matmul_output_hp )
404406 else :
405- allreduce_output_hp = self .collective_func (matmul_output_hp )
407+ allreduce_output_hp = self .row_parallel_collective_func (matmul_output_hp )
406408 return self .bias_add (allreduce_output_hp )
407409
408410 def forward_quant_reduce_in_hp (self , input ):
409411 matmul_output_hp = self .lp_matmul_hp (input )
410- all_reduce_output_hp = self .collective_func (matmul_output_hp )
412+ all_reduce_output_hp = self .row_parallel_collective_func (matmul_output_hp )
411413 return self .bias_add (all_reduce_output_hp )
412414
413415 def measure_input_and_matmul (self , input ):
@@ -426,7 +428,7 @@ def forward_measure_reduce(self, input):
426428 output = self .measure_input_and_matmul (input )
427429 max_output = output .clone ()
428430 dist .all_reduce (max_output , op = dist .ReduceOp .MAX )
429- all_reduce_output = self .collective_func (output )
431+ all_reduce_output = self .row_parallel_collective_func (output )
430432 measure_output ((max_output , all_reduce_output ,), self ._mod_extra_config .outputs )
431433 return self .bias_add (all_reduce_output )
432434
@@ -478,12 +480,14 @@ class PatchedColumnParallelLinear(PatchedLinearBase):
478480 def __init__ (self , mod , parent , mod_extra_config , * args , ** kwargs ):
479481 super ().__init__ (mod , parent , mod_extra_config , * args , ** kwargs )
480482 self .init_linear (mod_extra_config )
483+ from .._core .vllm_functions import get_vllm_column_parallel_collective_func
484+ self .column_parallel_collective_func = get_vllm_column_parallel_collective_func ()
481485
482486 def forward_qdq (self , input ):
483487 output = self .run_linear_qdq (input , None )
484488 output , output_bias = self .add_bias (output )
485489 if self .gather_output :
486- output = self .collective_func (output )
490+ output = self .column_parallel_collective_func (output )
487491 return output , output_bias
488492
489493 def forward_quant (self , input ):
@@ -492,7 +496,7 @@ def forward_quant(self, input):
492496 dqoutput = self .dequant_output (output )
493497 dqoutput , dqoutput_bias = self .add_bias (dqoutput )
494498 if self .gather_output :
495- dqoutput = self .collective_func (dqoutput )
499+ dqoutput = self .column_parallel_collective_func (dqoutput )
496500 return dqoutput , dqoutput_bias
497501
498502 def forward_measure (self , input ):
@@ -501,7 +505,7 @@ def forward_measure(self, input):
501505 measure_output ((output ,), self ._mod_extra_config .outputs )
502506 output , output_bias = self .add_bias (output )
503507 if self .gather_output :
504- output = self .collective_func (output )
508+ output = self .column_parallel_collective_func (output )
505509 return output , output_bias
506510
507511 def add_bias (self , output ):
0 commit comments