@@ -479,8 +479,6 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs(
479479 # Output block stride for the rearranged format
480480 BLOCK_ROWS , BLOCK_COLS = 128 , 4
481481 output_stride_per_block = BLOCK_ROWS * BLOCK_COLS
482- num_row_blocks = padded_rows // BLOCK_ROWS
483- output_stride_per_col_of_blocks = output_stride_per_block * num_row_blocks
484482
485483 # We parallelize per group and per row block.
486484 # Cols per group is variable, so we just loop through col blocks for each group.
@@ -495,17 +493,17 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs(
495493 scales_tensor .stride (1 ),
496494 rows ,
497495 cols ,
496+ padded_rows ,
498497 num_groups ,
499498 # Original offsets (to read from)
500499 input_group_end_offsets ,
501500 # Output scales tensor and group offsets after padding (to write to)
502501 output .view (torch .uint8 ),
503502 output_group_start_offsets ,
504503 output_stride_per_block ,
505- output_stride_per_col_of_blocks ,
506504 BLOCK_ROWS = BLOCK_ROWS ,
507505 BLOCK_COLS = BLOCK_COLS ,
508- DEBUG = True ,
506+ DEBUG = False ,
509507 )
510508 return output
511509
@@ -517,12 +515,12 @@ def triton_scale_swizzle_per_group_2d2d_lhs(
517515 scales_stride_dim1 ,
518516 scale_rows ,
519517 scale_cols ,
518+ padded_rows ,
520519 num_groups ,
521520 orig_offsets , # (num_groups,)
522521 output_scales_ptr ,
523522 output_scales_group_offsets , # (num_groups,)
524523 output_stride_per_block ,
525- output_stride_per_col_of_blocks ,
526524 BLOCK_ROWS : tl .constexpr ,
527525 BLOCK_COLS : tl .constexpr ,
528526 DEBUG : tl .constexpr = False ,
@@ -557,8 +555,9 @@ def triton_scale_swizzle_per_group_2d2d_lhs(
557555
558556 # For this group and row block, we iterate through col blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales.
559557 # We track how many col blocks we have iterated through.
558+ out_group_base_offset = output_group_start_col * padded_rows
560559 curr_input_start_col = input_group_start_col
561- curr_out_start_col_block = output_group_start_col // BLOCK_COLS
560+ curr_out_start_col_block = 0
562561 while curr_input_start_col < input_group_end_col :
563562 # Read block of input scales
564563 block_row_offs = block_row_pid * BLOCK_ROWS + row_offs
@@ -570,25 +569,21 @@ def triton_scale_swizzle_per_group_2d2d_lhs(
570569 input_scales = tl .load (scales_ptr + block_offs , mask = mask , other = 0.0 )
571570 scales_flat = tl .reshape (input_scales , (BLOCK_ROWS * BLOCK_COLS ))
572571
573- # Calculate block offset using provided output block stride
574- tgt_row_off = block_row_pid * output_stride_per_block
575- tgt_col_off = curr_out_start_col_block * output_stride_per_col_of_blocks
576-
577- output_block_offsets = tgt_row_off + tgt_col_off
578- if DEBUG :
579- tl .device_print ("\n block_row_pid: " , block_row_pid )
580- tl .device_print ("group_pid: " , group_pid )
581- tl .device_print ("tgt_row_block" , block_row_pid )
582- tl .device_print ("output_group_start_col: " , output_group_start_col )
583- tl .device_print ("tgt_col_block" , curr_out_start_col_block )
584- tl .device_print ("tgt_row_off: " , tgt_row_off )
585- tl .device_print ("tgt_col_off: " , tgt_col_off )
586- tl .device_print ("global_off:" , tgt_row_off + tgt_col_off )
587- tl .device_print ("writing: " , scales_flat )
572+ # Get offset within the group to add to the group's base offset
573+ num_cols_in_group = input_group_end_col - input_group_start_col
574+ num_col_blocks_in_group = tl .cdiv (num_cols_in_group , BLOCK_COLS )
575+ stride_per_row_of_blocks_in_group = (
576+ num_col_blocks_in_group * output_stride_per_block
577+ )
578+ offset_in_group = (
579+ block_row_pid * stride_per_row_of_blocks_in_group
580+ + curr_out_start_col_block * output_stride_per_block
581+ )
582+ final_offset = out_group_base_offset + offset_in_group
588583
589584 # Apply swizzling for write to gmem
590585 tl .store (
591- output_scales_ptr + output_block_offsets + dest_indices_flat ,
586+ output_scales_ptr + final_offset + dest_indices_flat ,
592587 scales_flat ,
593588 )
594589
0 commit comments