@@ -2916,78 +2916,194 @@ def cuda(self) -> bool:
 
 
 @register_quantize_op
-class MXFP8StackedGroupedGemm(QuantizeOpBase):
+class MXFP8GroupedGemm2d3d(QuantizeOpBase):
     """
-    MXFP8 grouped matmul with blockwise scaling and stacked inputs.
+    MXFP8 grouped GEMM with 2D inputs and 3D weights.
     """
 
     def preprocess(self, x, w):
-        m_values = [i.shape[0] for i in x]
-        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
+        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
+        return x, w
+
+    def quantize(self, x, w):
+        block_size = 32
+        G, N, K = w.shape
+        total_M = x.shape[0]
+        group_size = total_M // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
+        )
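+        # The offsets mark the end row of each (equal-sized) group along total_M and are
+        # passed to the grouped GEMM below to delimit the G independent problems.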
+
+        # For each constituent 2d subtensor in the 3d weights, quantize and convert its scale
+        # to blocked format separately, as each is used for an independent gemm in the grouped gemm.
         wq_list = []
         w_scale_list = []
-        for i in range(m_sizes.shape[0]):
+        for i in range(G):
             w_scale, wq = to_mxfp8(w[i])
             w_scale = _to_blocked(w_scale)
             wq_list.append(wq)
             w_scale_list.append(w_scale)
         wq = torch.stack(wq_list, dim=0).contiguous()
         w_scale = torch.stack(w_scale_list, dim=0).contiguous()
-        return x, wq, w_scale, m_sizes
 
-    def quantize(self, x, wq, w_scale, m_sizes):
-        starting_row_after_padding_list = [0]
+        # For each group along `total_M` in the 2D tensor, quantize and convert its scale
+        # to blocked format separately, as each is used for an independent gemm in the grouped gemm.
         xq_list = []
         x_scale_list = []
-        for i in range(m_sizes.shape[0]):
-            scale_slice = x[i]
-            if m_sizes[i].item() != 0:
-                x_scale, xq = to_mxfp8(scale_slice)
+        for i in range(G):
+            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
+            curr_group_end = input_group_end_offsets[i]
+            group_size = curr_group_end - prev_group_end
+            if group_size > 0:
+                x_slice = x[prev_group_end:curr_group_end, :]
+                x_scale, xq = to_mxfp8(x_slice)
                 x_scale = _to_blocked(x_scale)
                 xq_list.append(xq)
                 x_scale_list.append(x_scale)
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                    + x_scale.numel() // (x[0].shape[1] // 32)
-                )
-            else:
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                )
         xq = torch.cat(xq_list, dim=0).contiguous()
         x_scale = torch.cat(x_scale_list, dim=0).contiguous()
-        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
+        x_scale = x_scale.reshape(-1, K // block_size)
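+        # x_scale is reshaped so each row holds K // block_size scales; its rows include any
+        # padding rows added per group by _to_blocked above.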
         xq = xq.view(-1, xq.shape[-1])
-        return (
+        return xq, wq, x_scale, w_scale, input_group_end_offsets
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
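+        # wq was stacked as (G, N, K) in quantize; it is passed below as its transposed
+        # (G, K, N) view so the contraction runs over K.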
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
             xq,
-            wq,
+            wq.transpose(-2, -1),
             x_scale,
             w_scale,
-            m_sizes,
-            torch.tensor(starting_row_after_padding_list, device=xq.device),
+            input_group_end_offsets,
         )
 
-    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
-        return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
+        return self.compute(
             xq,
             wq,
             x_scale,
             w_scale,
-            m_sizes,
-            starting_row_after_padding=starting_row_after_padding,
+            input_group_end_offsets,
         )
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
-            x, w
+    @property
+    def name(self) -> str:
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
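+# Rough usage sketch for MXFP8GroupedGemm2d3d above (illustrative shapes/names; assumes
+# bf16 CUDA inputs as in the surrounding benchmark ops):
+#   x = [torch.randn(M, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]
+#   w = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]
+#   op = MXFP8GroupedGemm2d3d()
+#   out = op.quantize_and_compute(*op.preprocess(x, w))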
+@register_quantize_op
+class MXFP8GroupedGemm2d2d(QuantizeOpBase):
+    """
+    MXFP8 grouped GEMM with 2D inputs and 2D weights.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        G = len(x)
+        x = torch.cat(x, dim=1).contiguous()  # (M, total_K)
+        w = torch.cat(w, dim=1).contiguous()  # (N, total_K)
+        return x, w, G
+
+    def quantize(self, x, w, G):
+        # Simulate a 2d-2d grouped gemm in the backward pass `grad_weight = grad_output_t @ input`,
+        # where "K" is the contracting dim and has "G" groups.
+        M, total_K = x.shape
+        N, _ = w.shape
+        group_size = total_K // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
+        )
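+        # Here the groups partition the contracting dimension, so the offsets mark the end
+        # column of each group in both x (M, total_K) and w (N, total_K).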
+
+        # Quantize each group and convert its scales to blocked format.
+        x_list = []
+        w_list = []
+        x_blocked_scale_list = []
+        w_blocked_scale_list = []
+
+        def round_up(x: int, y: int) -> int:
+            return ((x + y - 1) // y) * y
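+        # _to_blocked pads each per-group scale tile to multiples of (128, 4), as the shape
+        # comments below note; round_up is used afterwards (M_rounded, N_rounded) to match
+        # that padding when the full scale tensors are reassembled.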
+
+        for group_idx in range(G):
+            # to_mxfp8 per group
+            prev_group_end_offset = (
+                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
+            )
+            curr_group_end_offset = input_group_end_offsets[group_idx]
+            group_size = curr_group_end_offset - prev_group_end_offset
+            if group_size > 0:
+                x_slice = x[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (M, K_group)
+                w_slice = w[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (N, K_group)
+                x_scale_slice, xq_slice = to_mxfp8(
+                    x_slice
+                )  # scale shape -> (M, K_group // 32)
+                w_scale_slice, wq_slice = to_mxfp8(
+                    w_slice
+                )  # scale shape -> (N, K_group // 32)
+                x_list.append(xq_slice)
+                w_list.append(wq_slice)
+
+                # Convert scales to blocked format.
+                x_scale_slice_blocked = _to_blocked(
+                    x_scale_slice
+                )  # (round_up(M, 128), round_up(K_group // 32, 4))
+                w_scale_slice_blocked = _to_blocked(
+                    w_scale_slice
+                )  # (round_up(N, 128), round_up(K_group // 32, 4))
+                x_blocked_scale_list.append(x_scale_slice_blocked)
+                w_blocked_scale_list.append(w_scale_slice_blocked)
+
+        # Assemble the full XQ and WQ.
+        xq = torch.cat(x_list, dim=1).contiguous()
+        wq = torch.cat(w_list, dim=1).contiguous()
+
+        # Combine all XQ groups' blocked scales into one tensor.
+        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
+        M_rounded = round_up(M, 128)
+        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
+
+        # Combine all WQ groups' blocked scales into one tensor.
+        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
+        N_rounded = round_up(N, 128)
+        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
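+        # torch.cat returns contiguous tensors, so the reshapes above are pure view changes:
+        # each group's 128-row-padded scale tile stays contiguous, in group order.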
+        return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
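+        # xq is (M, total_K) and wq is (N, total_K); wq is passed below as its transposed
+        # (total_K, N) view so the grouped contraction runs over K.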
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
+            xq,
+            wq.transpose(-2, -1),
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
+
+    def quantize_and_compute(self, x, w, G):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
         return self.compute(
-            xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
 
     @property
     def name(self) -> str:
-        return "cutlass_mx8mx8bf16_grouped_stacked"
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"
 
     @property
     def hip(self) -> bool: