import triton
import triton.language as tl

-from torchao.prototype.moe_training.utils import _is_column_major
-
EPS = 1e-12

FP8_DTYPE_MAP = {
    torch.float64: tl.float64,
}

-block_sizes = [128, 256]
+block_sizes = [1, 16, 32, 64]
+block_sizes_iter = [32, 64, 128, 256]
+num_warps = [1, 4]
+num_stages = [2, 3]
kernel_configs_2D = [
    triton.Config(
-        {"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}
+        {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
+        num_warps=warps,
+        num_stages=stages,
    )
-    for block_size_rows in block_sizes
-    for block_size_cols in block_sizes
+    for block_size in block_sizes
+    for block_size_iter in block_sizes_iter
+    for warps in num_warps
+    for stages in num_stages
]

from torch.library import triton_op, wrap_triton
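The expanded tuning space above is 4 block sizes x 4 iteration block sizes x 2 warp counts x 2 stage counts = 64 candidate configs per autotune key value. A quick standalone check of that count (plain Python, re-declaring the lists above purely for illustration):

from itertools import product

block_sizes = [1, 16, 32, 64]
block_sizes_iter = [32, 64, 128, 256]
num_warps = [1, 4]
num_stages = [2, 3]

# the autotuner benchmarks the candidate configs the first time it sees a new
# key value, so the size of this product bounds the one-off tuning cost
print(len(list(product(block_sizes, block_sizes_iter, num_warps, num_stages))))  # 64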
@@ -68,7 +73,6 @@ def triton_fp8_row_major_jagged_rowwise_scales(
        - jagged rowwise scales (i.e., rowwise scales for each group)
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
-    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -81,16 +85,14 @@ def triton_fp8_row_major_jagged_rowwise_scales(
    n_groups = offsets.numel()

    # allocate on-device buffers for output and scales
-    output_buffer = torch.empty_like(
-        hp_tensor, dtype=output_dtype, device=hp_tensor.device
-    )
+    output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device)
    scales_buffer = torch.empty(
        (m * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across rows and groups (offsets)
    grid = lambda meta: (
-        triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(m, meta["BLOCK_SIZE"]),
        offsets.numel(),
    )
    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
@@ -115,7 +117,13 @@ def triton_fp8_row_major_jagged_rowwise_scales(
    return output_buffer, scales_buffer


-@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+# This kernel is used on grad_output.t(), which has shape (K, M),
+# before the calculation `grad_B = grad_output_t @ input`.
+# However, in this code, we use the conventional dim names (M, K)
+# so the kernel is easily interpretable in a standalone fashion.
+# The tokens per expert will vary per iteration, so we don't want
+# to recompile when the `token` dim (K, in this case) changes.
+@triton.autotune(configs=kernel_configs_2D, key=["M"])
@triton.jit
def _triton_fp8_row_major_jagged_rowwise_scales(
    input_ptr,
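For intuition, here is a minimal eager-mode sketch (not part of this commit) of what the kernel above computes: for each group of columns delimited by `offsets`, a per-row amax, a per-row scale, and a float8 cast. The scaling formula (float8 max / amax, clamped by EPS), the function name, and the signature are assumptions for illustration; empty groups are not handled.

import torch

def jagged_rowwise_scales_reference(hp, offsets, float8_dtype=torch.float8_e4m3fn, eps=1e-12):
    # hp: (M, K) high-precision tensor; offsets: cumulative group end indices along K
    m, k = hp.shape
    f8_max = torch.finfo(float8_dtype).max
    out = torch.empty(m, k, dtype=float8_dtype, device=hp.device)
    scales = torch.empty(m * offsets.numel(), dtype=torch.float32, device=hp.device)
    start = 0
    for g, end in enumerate(offsets.tolist()):
        group = hp[:, start:end].float()               # this group's slice of columns
        amax = group.abs().amax(dim=1).clamp(min=eps)  # one amax per row, per group
        scale = f8_max / amax                          # rowwise scales for this group
        scales[g * m : (g + 1) * m] = scale            # contiguous per-group layout
        out[:, start:end] = (group * scale[:, None]).clamp(-f8_max, f8_max).to(float8_dtype)
        start = end
    return out, scales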
@@ -134,8 +142,8 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
-    BLOCK_SIZE_ROWS: tl.constexpr,
-    BLOCK_SIZE_COLS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_SIZE_ITER: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallel across rows and groups (offsets)
@@ -147,12 +155,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_col_end_idx = tl.load(offsets_ptr + offset_idx)
-    block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # compute rowwise amaxes for this group
-    amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=input_dtype)
-    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+    amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
@@ -180,12 +188,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
    # store rowwise scales for each group in contiguous memory:
    # [group0_row0, group_0_row1, ..., group2_row0, group2_row1]
    scales_offs = block_row_offs + (M * offset_idx)
-    scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M
+    scales_mask = tl.arange(0, BLOCK_SIZE) < M
    tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

    # perform float8 conversion for this group
-    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
@@ -230,7 +238,6 @@ def triton_fp8_col_major_jagged_colwise_scales(
        - jagged column-wise scales (i.e., column-wise scales for each group)
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
-    assert _is_column_major(hp_tensor), "input tensor must be column-major"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -242,17 +249,18 @@ def triton_fp8_col_major_jagged_colwise_scales(
    k, n = hp_tensor.shape
    n_groups = offsets.numel()

-    # allocate on-device buffers for output and scales
+    # Output buffer in column major
    output_buffer = torch.empty_like(
        hp_tensor, dtype=output_dtype, device=hp_tensor.device
-    )
+    ).as_strided(hp_tensor.size(), (1, k))
+
    scales_buffer = torch.empty(
        (n * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across columns and groups (offsets)
    grid = lambda meta: (
-        triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
+        triton.cdiv(n, meta["BLOCK_SIZE"]),
        offsets.numel(),
    )
    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
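With the column-major input assert removed, `torch.empty_like` alone would no longer guarantee a column-major result, so the output buffer above is pinned to a column-major layout via `as_strided(hp_tensor.size(), (1, k))`. A small standalone illustration of that stride trick (shapes here are arbitrary):

import torch

k, n = 4, 6
buf = torch.empty(k, n).as_strided((k, n), (1, k))  # reinterpret the storage as column-major
print(buf.stride())             # (1, 4): elements within a column are adjacent in memory
print(buf.t().is_contiguous())  # True: the transpose is row-major contiguous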
@@ -277,7 +285,11 @@ def triton_fp8_col_major_jagged_colwise_scales(
    return output_buffer, scales_buffer


-@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+# This kernel is used on `input`, which has shape (M, K),
+# before the calculation `grad_B = grad_output_t @ input`.
+# The tokens per expert will vary per iteration, so we don't want
+# to recompile when the `token` dim (M) changes.
+@triton.autotune(configs=kernel_configs_2D, key=["K"])
@triton.jit
def _triton_fp8_col_major_jagged_colwise_scales(
    input_ptr,
@@ -296,8 +308,8 @@ def _triton_fp8_col_major_jagged_colwise_scales(
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
-    BLOCK_SIZE_ROWS: tl.constexpr,
-    BLOCK_SIZE_COLS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_SIZE_ITER: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallel across columns and groups (offsets)
@@ -309,12 +321,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_row_end_idx = tl.load(offsets_ptr + offset_idx)
-    block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
+    block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # compute colwise amaxes for this group
-    amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=input_dtype)
-    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
-        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
+    amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
+        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
@@ -343,12 +355,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
    # [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
    # note: input tensor is in col-major memory layout.
    scales_offs = block_col_offs + (N * offset_idx)
-    scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N
+    scales_mask = tl.arange(0, BLOCK_SIZE) < N
    tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

    # perform float8 conversion for this group
-    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
-        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
+    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
+        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
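Taken together, the comments above describe the `grad_B = grad_output_t @ input` path: the transposed grad_output gets jagged rowwise scales, the input gets jagged colwise scales, and `offsets` marks the per-expert token boundaries. A hypothetical call pattern just to make that concrete (tensor names, sizes, and the exact argument order of the wrappers are assumptions, not taken from this commit):

import torch

total_tokens, N, K = 4096, 8192, 2048
grad_output = torch.randn(total_tokens, N, device="cuda", dtype=torch.bfloat16)
x = torch.randn(total_tokens, K, device="cuda", dtype=torch.bfloat16)
# cumulative tokens per expert; this varies every iteration, hence the autotune keys above
offsets = torch.tensor([1024, 2560, 4096], device="cuda", dtype=torch.int32)

# grad_output.t() is (N, total_tokens): offsets partition its columns (the token dim),
# so it gets rowwise scales per group; the contiguity assert was removed, so the
# transposed view can be passed without an extra copy
go_t_fp8, go_t_scales = triton_fp8_row_major_jagged_rowwise_scales(
    grad_output.t(), offsets, torch.float8_e4m3fn
)
# x is (total_tokens, K): offsets partition its rows (the token dim),
# so it gets colwise scales per group
x_fp8, x_scales = triton_fp8_col_major_jagged_colwise_scales(
    x, offsets, torch.float8_e4m3fn
)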