@@ -893,10 +893,15 @@ def _replace_linear_8da4w(
893893 linear_class : Type [torch .nn .Module ],
894894 copy_weights : bool = False ,
895895):
896- for name , child in module .named_children ():
897- if isinstance (child , nn .Linear ):
898- if _check_linear_int4_k (child .in_features , groupsize ) or padding_allowed :
899- new_linear = linear_class (
896+
897+ # Import the util function here to avoid a circular dependency
898+ from torchao .quantization .quant_api import _replace_with_custom_fn_if_matches_filter
899+
900+ def filter_fn (child : torch .nn .Module , cur_fqn :str ) -> bool :
901+ return isinstance (child , nn .Linear ) and (_check_linear_int4_k (child .in_features , groupsize ) or padding_allowed )
902+
903+ def replacement_fn (child : torch .nn .Module ) -> torch .nn .Module :
904+ new_linear = linear_class (
900905 child .in_features ,
901906 child .out_features ,
902907 bias = False ,
@@ -905,22 +910,14 @@ def _replace_linear_8da4w(
905910 precision = precision ,
906911 scales_precision = scales_precision ,
907912 )
908- # In distributed training, the model may be instantiated
909- # on the meta device, in which case there is no need to
910- # copy the weights, and doing so will result in an error
911- if copy_weights and child .weight .device != torch .device ("meta" ):
912- new_linear .weight = child .weight
913- setattr (module , name , new_linear )
914- else :
915- _replace_linear_8da4w (
916- child ,
917- groupsize ,
918- padding_allowed ,
919- precision ,
920- scales_precision ,
921- linear_class ,
922- copy_weights ,
923- )
913+ # In distributed training, the model may be instantiated
914+ # on the meta device, in which case there is no need to
915+ # copy the weights, and doing so will result in an error
916+ if copy_weights and child .weight .device != torch .device ("meta" ):
917+ new_linear .weight = child .weight
918+ return new_linear
919+
920+ _replace_with_custom_fn_if_matches_filter (module , replacement_fn , filter_fn )
924921
925922def replace_linear_8da4w (
926923 module : torch .nn .Module ,
0 commit comments