
Commit 73eec2e

Update
[ghstack-poisoned]
2 parents: 00fb1b8 + 1ab1b77

File tree

3 files changed: +59, -2 lines

benchmarks/float8/profile_lowp_training.py

Lines changed: 25 additions & 1 deletion
@@ -306,8 +306,9 @@ def main(
             "fwd",
             "cast_only",
             "cast_with_to_blocked",
+            "cast_only_dim0_dim1",
         )
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"

@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
         scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
         return x_mx._data, scale_blocked

+    # this function is used for cast_only_dim0_dim1
+    def cast_only_dim0_dim1(x_hp):
+        x_hp_t_c = x_hp.t().contiguous()
+        x_mx_dim0 = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        x_mx_dim1 = MXTensor.to_mx(
+            x_hp_t_c,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        return x_mx_dim0, x_mx_dim1
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)

@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
         elif mode_filter == "cast_with_to_blocked":
             _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
             return
+        elif mode_filter == "cast_only_dim0_dim1":
+            _input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
+                input_tensor,
+            )
+            return

         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)

@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
         cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
+        cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)

     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
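
For context, the new `cast_only_dim0_dim1` mode isolates the cost of casting the same activation to MX twice: once along dim0 and once along dim1 of the original tensor via a transpose + contiguous copy, mirroring the `cast_only_dim0_dim1` helper added above. A minimal standalone sketch of that pattern outside the benchmark (the input size, bfloat16 input dtype, fp8 element dtype, and block size 32 below are illustrative choices, not the benchmark's configuration):

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_hp = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# dim0 cast: scale blocks run along the rows of x_hp
x_mx_dim0 = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)

# dim1 cast: transpose + contiguous first, so the same row-wise cast
# effectively scales along the columns of the original tensor
x_mx_dim1 = MXTensor.to_mx(x_hp.t().contiguous(), torch.float8_e4m3fn, 32)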

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 24 additions & 0 deletions
@@ -6,6 +6,8 @@

 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck

 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (

@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
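
The new test leans on two utilities: `run_and_get_code`, which runs a compiled callable and also returns the inductor-generated wrapper source, and `FileCheck`, which asserts patterns over that source; requiring exactly one `.run(` after `def call(` means exactly one triton kernel launch, i.e. the whole cast fused. A minimal sketch of the same check on a trivial elementwise function (the function `f` and the tensor shape are illustrative; assumes a CUDA device and a PyTorch build with inductor):

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def f(x):
    # two elementwise ops that inductor is expected to fuse into one kernel
    return (x * 2.0).relu()


f_c = torch.compile(f, fullgraph=True)
x = torch.randn(1024, 1024, device="cuda")
_, code = run_and_get_code(f_c, x)
# the generated wrapper exposes a single `def call(` entry point; each fused
# triton kernel shows up as one `.run(` launch inside it
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])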

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 10 additions & 1 deletion
@@ -205,16 +205,25 @@ def to_mx(
     data_lp = torch.clamp(
         data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
-    data_lp = data_lp.reshape(orig_shape)

     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_lp = data_lp.to(elem_dtype)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E2M3:
         data_lp = f32_to_f6_e2m3_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E3M2:
         data_lp = f32_to_f6_e3m2_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP4:
+        # can't reshape at the end without handling it in the packing code,
+        # punt until later since we'll need to rethink the torch.compile
+        # approach for fp4x2 in any case
+        data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
     else:
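
This change keeps the divide / clamp / low precision cast chain together and only reshapes back to `orig_shape` after the element dtype cast (fp4 keeps the old order because the packing code would need to handle the reshaped layout). The reordering, isolated into a standalone sketch for the fp8 branch (the function names and fp8 dtype are illustrative; `data_hp` is assumed to already be reshaped to `(num_blocks, block_size)`, as it is at this point in `to_mx`):

import torch


def to_lowp_reshape_mid(data_hp, scale_fp, max_pos, orig_shape):
    # old ordering: the reshape sits between the clamp and the dtype cast
    data_lp = torch.clamp(
        data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
    )
    data_lp = data_lp.reshape(orig_shape)
    return data_lp.to(torch.float8_e4m3fn)


def to_lowp_reshape_last(data_hp, scale_fp, max_pos, orig_shape):
    # new ordering: cast first, reshape at the very end, which per the
    # inline comments in this diff helps inductor fuse the whole chain
    data_lp = torch.clamp(
        data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
    )
    data_lp = data_lp.to(torch.float8_e4m3fn)
    return data_lp.reshape(orig_shape)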
