@@ -10,6 +10,10 @@
 
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+    triton_fp8_col_major_jagged_colwise_scales,
+    triton_fp8_row_major_jagged_rowwise_scales,
+)
 from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
 
 
@@ -189,17 +193,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # grad_B is a special case: both operands of the grouped gemm will be 2D, with offsets determining the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets that define the "jagged" groups.
         grad_output_t_fp8_row_major, grad_output_t_scales = (
-            _to_2d_jagged_float8_tensor_rowwise(
+            triton_fp8_row_major_jagged_rowwise_scales(
                 grad_output_t_row_major,
                 offs,
-                target_dtype=torch.float8_e4m3fn,
+                output_dtype=torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
         )
-        A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise(
+
+        A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
             A_col_major,
             offs,
-            target_dtype=torch.float8_e4m3fn,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
 
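For context, the fused Triton kernels are drop-in replacements for the Python reference helpers removed below, with `target_dtype` renamed to `output_dtype`. A minimal usage sketch: the import path, signature, and return order follow the call sites above, while the shapes, dtypes, and offsets are illustrative assumptions (the kernels require a CUDA device):

```python
# Sketch of calling the fused jagged rowwise-scaling kernel directly.
# Signature taken from the diff above; inputs here are made up.
import torch

from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_row_major_jagged_rowwise_scales,
)

x = torch.randn(16, 96, device="cuda", dtype=torch.bfloat16)  # (M, K), row-major
offs = torch.tensor([32, 64, 96], device="cuda", dtype=torch.int32)  # 3 groups along K

# Returns the float8 tensor plus one rowwise scale per (row, group) pair,
# i.e. x_scales has M * num_groups elements.
x_fp8, x_scales = triton_fp8_row_major_jagged_rowwise_scales(
    x,
    offs,
    output_dtype=torch.float8_e4m3fn,
    round_scales_to_power_of_2=True,
)
```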
@@ -216,139 +221,3 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
-
-
-def _to_2d_jagged_float8_tensor_colwise(
-    A_col_major: torch.Tensor,
-    offs: torch.Tensor,
-    target_dtype: torch.dtype = torch.float8_e4m3fn,
-    round_scales_to_power_of_2: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    This function converts the 2D input tensor A to a jagged float8 tensor,
-    with scales computed along *logical columns* for each group individually,
-    where groups are determined based on the offsets.
-
-    For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns
-    (i.e., a tensor of shape (K,N) will have scales of shape (1,N)).
-
-    However, for a 2D right operand of a grouped GEMM, these logical columns pass through multiple distinct
-    groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
-    along the logical columns and apply it to the entire tensor.
-
-    Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K,N) this results
-    in scales of shape (1, N * num_groups).
-
-    Args:
-        A_col_major (torch.Tensor): The 2D input tensor to be converted to a jagged float8 tensor.
-        offs (torch.Tensor): Group-ending row offsets that define the jagged groups.
-
-    Returns:
-        A tuple containing the jagged float8 tensor and the scales used for the conversion.
-    """
-    assert A_col_major.ndim == 2, "A must be 2D"
-
-    num_groups = offs.numel()
-    A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype)
-    A_scales = torch.empty(
-        A_fp8_col_major.size(1) * num_groups,
-        dtype=torch.float32,
-        device=A_fp8_col_major.device,
-    )
-
-    start_idx = 0
-    next_scale_idx = 0
-    for end_idx in offs.tolist():
-        # Get the subtensor of A for this group: the next group of rows, with all columns for each.
-        subtensor = A_col_major[start_idx:end_idx, :]  # (local_group_size, N)
-
-        # Compute local scales for this subtensor, which run along logical columns for the right operand.
-        subtensor_scales = tensor_to_scale(
-            subtensor,
-            target_dtype,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=0,
-            round_scales_to_power_of_2=round_scales_to_power_of_2,
-        )
-
-        # Apply scales to the subtensor and convert to float8.
-        tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
-        float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)
-
-        # Store this portion of the resulting float8 tensor and scales.
-        A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor
-        A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
-            subtensor_scales.squeeze()
-        )
-
-        # Update the start index and scale index for the next group.
-        start_idx = end_idx
-        next_scale_idx += subtensor_scales.numel()
-
-    return A_fp8_col_major, A_scales
-
-
-def _to_2d_jagged_float8_tensor_rowwise(
-    x: torch.Tensor,
-    offs: torch.Tensor,
-    target_dtype: torch.dtype,
-    round_scales_to_power_of_2: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    This function converts the 2D input tensor to a jagged float8 tensor,
-    with scales computed along *logical rows* for each group individually,
-    where groups are determined based on the offsets.
-
-    For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows
-    (i.e., a tensor of shape (M,K) will have scales of shape (M,1)).
-
-    However, for a 2D left operand of a grouped GEMM, these logical rows pass through multiple distinct
-    groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
-    along the logical rows and apply it to the entire tensor.
-
-    Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M,K) this results
-    in scales of shape (M * num_groups, 1).
-
-    Args:
-        x (torch.Tensor): The 2D input tensor to be converted to a jagged float8 tensor.
-        offs (torch.Tensor): Group-ending column offsets that define the jagged groups.
-
-    Returns:
-        A tuple containing the jagged float8 tensor and the scales used for the conversion.
-    """
-    assert x.ndim == 2, "input tensor must be 2D"
-
-    num_groups = offs.numel()
-    x_fp8 = torch.empty_like(x, dtype=target_dtype)
-    x_scales = torch.empty(
-        x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device
-    )
-
-    start_idx = 0
-    next_scale_idx = 0
-    for end_idx in offs.tolist():
-        # Get the subtensor of x for this group: all rows, with the next group of columns.
-        subtensor = x[:, start_idx:end_idx]  # (M, local_group_size)
-
-        # Compute local scales for this subtensor, which run along logical rows for the left operand.
-        subtensor_scales = tensor_to_scale(
-            subtensor,
-            target_dtype,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1,
-            round_scales_to_power_of_2=round_scales_to_power_of_2,
-        )
-
-        # Apply scales to the subtensor and convert to float8.
-        tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
-        float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)
-
-        # Store this portion of the resulting float8 tensor and scales.
-        x_fp8[:, start_idx:end_idx] = float8_subtensor
-        x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
-            subtensor_scales.squeeze()
-        )
-
-        # Update the start index and scale index for the next group.
-        start_idx = end_idx
-        next_scale_idx += subtensor_scales.numel()
-
-    return x_fp8, x_scales
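Conceptually, both deleted helpers compute one scale per (row-or-column, group) pair, rather than one scale per logical row or column of the whole tensor. A minimal self-contained sketch of that shape logic for the rowwise case, assuming plain amax-based scaling in place of torchao's `tensor_to_scale` (448.0 is the max magnitude of `float8_e4m3fn`):

```python
# Standalone illustration (not torchao code) of jagged rowwise scale shapes.
import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def jagged_rowwise_scales(x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # For an (M, K) tensor whose columns are partitioned into jagged groups by
    # `offs`, compute one scale per (row, group) pair: M * num_groups scales.
    scales = []
    start = 0
    for end in offs.tolist():
        group = x[:, start:end]                           # (M, local_group_size)
        amax = group.abs().amax(dim=-1).clamp(min=1e-12)  # (M,) per-row max magnitude
        scales.append(FP8_E4M3_MAX / amax.float())        # scale maps amax -> fp8 max
        start = end
    return torch.cat(scales)

x = torch.randn(2, 5)
offs = torch.tensor([2, 5])  # two groups: columns [0:2] and [2:5]
print(jagged_rowwise_scales(x, offs).shape)  # torch.Size([4]) == M * num_groups
```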