@@ -2873,78 +2873,194 @@ def cuda(self) -> bool:


 @register_quantize_op
-class MXFP8StackedGroupedGemm(QuantizeOpBase):
+class MXFP8GroupedGemm2d3d(QuantizeOpBase):
     """
-    MXFP8 grouped matmul with blockwise scaling and stacked inputs.
+    MXFP8 grouped GEMM with 2D inputs and 3D weights.
     """

     def preprocess(self, x, w):
-        m_values = [i.shape[0] for i in x]
-        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
+        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
+        return x, w
+
+    def quantize(self, x, w):
+        block_size = 32
+        G, N, K = w.shape
+        total_M = x.shape[0]
+        group_size = total_M // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
+        )
+
+        # For each constituent 2D subtensor of the 3D weights, quantize and convert its scale to
+        # blocked format separately, as each is used for an independent gemm within the grouped gemm.
         wq_list = []
         w_scale_list = []
-        for i in range(m_sizes.shape[0]):
+        for i in range(G):
             w_scale, wq = to_mxfp8(w[i])
             w_scale = _to_blocked(w_scale)
             wq_list.append(wq)
             w_scale_list.append(w_scale)
         wq = torch.stack(wq_list, dim=0).contiguous()
         w_scale = torch.stack(w_scale_list, dim=0).contiguous()
-        return x, wq, w_scale, m_sizes

-    def quantize(self, x, wq, w_scale, m_sizes):
-        starting_row_after_padding_list = [0]
+        # For each group along `total_M` in the 2D input, quantize and convert its scale to
+        # blocked format separately, as each is used for an independent gemm within the grouped gemm.
         xq_list = []
         x_scale_list = []
-        for i in range(m_sizes.shape[0]):
-            scale_slice = x[i]
-            if m_sizes[i].item() != 0:
-                x_scale, xq = to_mxfp8(scale_slice)
+        for i in range(G):
+            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
+            curr_group_end = input_group_end_offsets[i]
+            group_size = curr_group_end - prev_group_end
+            if group_size > 0:
+                x_slice = x[prev_group_end:curr_group_end, :]
+                x_scale, xq = to_mxfp8(x_slice)
                 x_scale = _to_blocked(x_scale)
                 xq_list.append(xq)
                 x_scale_list.append(x_scale)
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                    + x_scale.numel() // (x[0].shape[1] // 32)
-                )
-            else:
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                )
         xq = torch.cat(xq_list, dim=0).contiguous()
         x_scale = torch.cat(x_scale_list, dim=0).contiguous()
-        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
+        x_scale = x_scale.reshape(-1, K // block_size)
         xq = xq.view(-1, xq.shape[-1])
-        return (
+        return xq, wq, x_scale, w_scale, input_group_end_offsets
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
             xq,
-            wq,
+            wq.transpose(-2, -1),
             x_scale,
             w_scale,
-            m_sizes,
-            torch.tensor(starting_row_after_padding_list, device=xq.device),
+            input_group_end_offsets,
         )

-    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
-        return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
+        return self.compute(
             xq,
             wq,
             x_scale,
             w_scale,
-            m_sizes,
-            starting_row_after_padding=starting_row_after_padding,
+            input_group_end_offsets,
         )

-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
-            x, w
+    @property
+    def name(self) -> str:
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class MXFP8GroupedGemm2d2d(QuantizeOpBase):
+    """
+    MXFP8 grouped GEMM with 2D inputs and 2D weights.
2965+ """
2966+
2967+ def preprocess (self , x , w ):
2968+ assert isinstance (x , list )
2969+ assert isinstance (w , list )
2970+ G = len (x )
2971+ x = torch .cat (x , dim = 1 ).contiguous () # (M, total_K)
2972+ w = torch .cat (w , dim = 1 ).contiguous () # (N, total_K)
2973+ return x , w , G
2974+
2975+ def quantize (self , x , w , G ):
+        # Simulate a 2d-2d grouped gemm in the backward pass `grad_weight = grad_output_t @ input`,
+        # where "K" is the contracting dim and is split into "G" groups.
+        M, total_K = x.shape
+        N, _ = w.shape
+        group_size = total_K // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
+        )
+
+        # Quantize each group and convert its scales to blocked format.
+        x_list = []
+        w_list = []
+        x_blocked_scale_list = []
+        w_blocked_scale_list = []
+
+        def round_up(x: int, y: int) -> int:
+            return ((x + y - 1) // y) * y
+
+        for group_idx in range(G):
+            # to_mxfp8 per group
+            prev_group_end_offset = (
+                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
+            )
+            curr_group_end_offset = input_group_end_offsets[group_idx]
+            group_size = curr_group_end_offset - prev_group_end_offset
+            if group_size > 0:
+                x_slice = x[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (M, K_group)
+                w_slice = w[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (N, K_group)
+                x_scale_slice, xq_slice = to_mxfp8(
+                    x_slice
+                )  # scale shape -> (M, K_group // 32)
+                w_scale_slice, wq_slice = to_mxfp8(
+                    w_slice
+                )  # scale shape -> (N, K_group // 32)
+                x_list.append(xq_slice)
+                w_list.append(wq_slice)

+                # Convert scales to blocked format.
+                x_scale_slice_blocked = _to_blocked(
+                    x_scale_slice
+                )  # (round_up(M, 128), round_up(K_group // 32, 4))
+                w_scale_slice_blocked = _to_blocked(
+                    w_scale_slice
+                )  # (round_up(N, 128), round_up(K_group // 32, 4))
+                x_blocked_scale_list.append(x_scale_slice_blocked)
+                w_blocked_scale_list.append(w_scale_slice_blocked)
+
+        # Assemble the full XQ and WQ.
+        xq = torch.cat(x_list, dim=1).contiguous()
+        wq = torch.cat(w_list, dim=1).contiguous()
+
+        # Combine all XQ groups' blocked scales into one tensor.
+        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
+        M_rounded = round_up(M, 128)
+        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
+
+        # Combine all WQ groups' blocked scales into one tensor.
+        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
+        N_rounded = round_up(N, 128)
+        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
+        return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
+            xq,
+            wq.transpose(-2, -1),
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
+
+    def quantize_and_compute(self, x, w, G):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
         return self.compute(
-            xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )

     @property
     def name(self) -> str:
-        return "cutlass_mx8mx8bf16_grouped_stacked"
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"

     @property
     def hip(self) -> bool:
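For reference, the snippet below sketches the uniform group-offset convention that both new benchmark ops rely on: the 2d-3d op stacks the G activations along the row (M) dimension and delimits groups with end offsets over rows, while the 2d-2d op concatenates along the contracting (K) dimension and delimits groups with end offsets over columns. This is a minimal standalone sketch in plain PyTorch; the shapes and variable names are illustrative only and are not part of this change.

import torch

G, M, N, K = 4, 64, 32, 128

# 2d-3d case: G inputs of shape (M, K) stacked along dim 0 into (G * M, K);
# groups are delimited by end offsets over the row dimension.
x_stacked = torch.randn(G * M, K)
m_offsets = torch.arange(M, G * M + 1, M, dtype=torch.int32)

# 2d-2d case: the contracting dim is split into G groups of size K;
# groups are delimited by end offsets over the column dimension.
x_wide = torch.randn(M, G * K)
k_offsets = torch.arange(K, G * K + 1, K, dtype=torch.int32)

# Per-group slicing, mirroring the loops in the `quantize` methods above.
for i in range(G):
    row_start = 0 if i == 0 else int(m_offsets[i - 1])
    group_2d3d = x_stacked[row_start : int(m_offsets[i]), :]  # (M, K)
    col_start = 0 if i == 0 else int(k_offsets[i - 1])
    group_2d2d = x_wide[:, col_start : int(k_offsets[i])]  # (M, K)
    assert group_2d3d.shape == (M, K) and group_2d2d.shape == (M, K)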