@@ -1383,6 +1383,124 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1,
         )
 
+    @triton.jit
+    def triton_scale_swizzle(
+        scale_ptr,
+        scale_rows,
+        scale_cols,
+        output_ptr,
+        input_row_stride,
+        output_block_stride,
+        BLOCK_ROWS: tl.constexpr,
+        BLOCK_COLS: tl.constexpr,
+    ):
+        """
+        Rearranges tensor data from row-major to block-scaled swizzle format.
+
+        Args:
+            scale_ptr: Pointer to the input scale tensor
+            scale_rows: Number of rows in the scale tensor
+            scale_cols: Number of columns in the scale tensor
+            output_ptr: Pointer to the output tensor
+            input_row_stride: Stride between rows in the input tensor
+            output_block_stride: Stride between blocks in the output tensor
+            BLOCK_ROWS: Number of rows in a tile (compile-time constant)
+            BLOCK_COLS: Number of columns in a tile (compile-time constant)
+        """
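+        # Each program instance handles one (BLOCK_ROWS, BLOCK_COLS) tile of the
+        # scale matrix and writes it to one contiguous block of the output.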
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        rows = tl.arange(0, BLOCK_ROWS)[:, None]
+        cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * BLOCK_ROWS
+        start_col = pid_col * BLOCK_COLS
+        global_rows = start_row + rows
+        global_cols = start_col + cols
+
+        mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+        input_scales = tl.load(
+            scale_ptr + global_rows * input_row_stride + global_cols,
+            mask=mask,
+            other=0.0,
+        )
+
+        r_div_32 = rows // 32
+        r_mod_32 = rows % 32
+
+        # Rearrange each tile to (32, 4, 4) and then to the final (32, 16) coordinates
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
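+        # For the 128x4 tile size used by the wrapper below, element (r, c)
+        # lands at flat offset (r % 32) * 16 + (r // 32) * 4 + c within its
+        # 512-element block, e.g. (r=33, c=2) -> 16 + 4 + 2 = 22.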
+
+        # Flatten
+        dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
+        scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
+
+        # Calculate block offset using provided output block stride
+        LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+        block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
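+        # Each tile's LOCAL_NUMEL values are stored contiguously; consecutive
+        # column blocks advance by LOCAL_NUMEL and consecutive row blocks
+        # advance by output_block_stride in the flattened output.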
+
+        tl.store(
+            output_ptr + block_offset + dest_indices_flat,
+            scales_flat,
+        )
+
+    def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
+
+        This format is suitable for Tmem as described in NVIDIA documentation:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+        Args:
+            scale_tensor: Input tensor in row-major format with 8-bit elements
+
+        Returns:
+            Rearranged tensor in block-scaled swizzle format
+        """
+        assert scale_tensor.element_size() == 1, (
+            "Expected element size to be 1 byte (8 bits)"
+        )
+        assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+        rows, cols = scale_tensor.shape
+
+        # Calculate blocks needed
+        n_row_blocks = triton.cdiv(rows, 128)
+        n_col_blocks = triton.cdiv(cols, 4)
+        padded_rows = n_row_blocks * 128
+        padded_cols = n_col_blocks * 4
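+        # The swizzled layout works on whole 128x4 tiles of scales, so the output
+        # is padded up to full tiles; out-of-range positions are written as zero
+        # by the kernel (masked loads use other=0.0).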
+
+        out = scale_tensor.new_empty((padded_rows, padded_cols))
+
+        # Input stride (for row-major format)
+        input_row_stride = cols
+
+        # We probably want to handle multiple blocks per tile, but for now keep it simple
+        BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+        # Output block stride for the rearranged format
+        output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
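+        # This is the element distance between the first blocks of consecutive
+        # 128-row bands: BLOCK_ROWS * BLOCK_COLS elements per block times the
+        # number of column blocks (padded_cols // BLOCK_COLS).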
+
+        grid = lambda META: (
+            triton.cdiv(padded_rows, BLOCK_ROWS),
+            triton.cdiv(padded_cols, BLOCK_COLS),
+        )
+
+        wrap_triton(triton_scale_swizzle)[grid](
+            scale_tensor.view(torch.uint8),
+            rows,
+            cols,
+            out.view(torch.uint8),
+            input_row_stride,
+            output_block_stride,
+            BLOCK_ROWS=BLOCK_ROWS,
+            BLOCK_COLS=BLOCK_COLS,
+        )
+
+        return out
+
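+    # Example usage (illustrative; `scale_e8m0` stands for any 2D tensor of
+    # 1-byte elements, such as e8m0 scales):
+    #     swizzled = triton_mx_block_rearrange(scale_e8m0)
+    #     # swizzled.shape == (ceil(rows / 128) * 128, ceil(cols / 4) * 4)
+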
 else:
 
     def triton_to_mxfp8_dim1(
@@ -1394,3 +1512,6 @@ def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
+
+    def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
+        raise AssertionError("needs torch version 2.8+ and triton")