File tree: 2 files changed (+11 -1 lines changed)
torchao/prototype/mx_formats: 2 files changed (+11 -1 lines changed)
Original file line number | Diff line number | Diff line change
@@ -1448,6 +1448,7 @@ def triton_scale_swizzle(
14481448 scales_flat ,
14491449 )
14501450
1451+ @torch .library .custom_op ("torchao::triton_mx_block_rearrange" , mutates_args = ())
14511452 def triton_mx_block_rearrange (scale_tensor : torch .Tensor ) -> torch .Tensor :
14521453 """
14531454 Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
@@ -1716,6 +1717,15 @@ def _(x, per_tensor_scale=None):
17161717 xq = torch .empty (M , N // 2 , device = x .device , dtype = torch .uint8 )
17171718 return scales , xq
17181719
@triton_mx_block_rearrange.register_fake
def _(scale_tensor):
    """Fake implementation registered for ``triton_mx_block_rearrange``.

    Computes only the output shape and allocates an uninitialized tensor;
    no data from ``scale_tensor`` is read or rearranged here.

    Args:
        scale_tensor: 2-D tensor (its shape is unpacked as ``rows, cols``).

    Returns:
        A tensor created with ``scale_tensor.new_empty`` whose row count is
        rounded up to a multiple of 128 and whose column count is rounded
        up to a multiple of 4.
    """
    num_rows, num_cols = scale_tensor.shape

    # Ceil-divide into 128-row x 4-column blocks, then scale back up to get
    # the padded extents.
    # NOTE(review): presumably mirrors the padding done by the real op --
    # confirm against the kernel wrapper.
    out_rows = triton.cdiv(num_rows, 128) * 128
    out_cols = triton.cdiv(num_cols, 4) * 4

    return scale_tensor.new_empty((out_rows, out_cols))
17191729else :
17201730
17211731 def triton_to_mxfp8_dim1 (
Original file line number Diff line number Diff line change @@ -15,7 +15,7 @@ def ceil_div(a, b):
1515 return (a + b - 1 ) // b
1616
1717
18- def to_blocked (input_matrix , use_triton_kernel : bool = True ) -> Tensor :
18+ def to_blocked (input_matrix , use_triton_kernel : bool = False ) -> Tensor :
1919 """
2020 Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
2121
You can’t perform that action at this time.
0 commit comments