benchmarks/float8/profile_lowp_training.py (25 additions, 1 deletion)
@@ -306,8 +306,9 @@ def main(
"fwd",
"cast_only",
"cast_with_to_blocked",
"cast_only_dim0_dim1",
), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
if mode_filter == "cast_only":
assert experiment_filter == "lowp", "unsupported"

@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
return x_mx._data, scale_blocked

# casts the input to MX along both dim0 and dim1; used by the cast_only_dim0_dim1 mode
def cast_only_dim0_dim1(x_hp):
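# materialize the transpose up front so the dim1 cast reads from contiguous memory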
x_hp_t_c = x_hp.t().contiguous()
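# first cast: to_mx on the original row-major layout (the dim0 cast)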
x_mx_dim0 = MXTensor.to_mx(
x_hp,
config.elem_dtype,
config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
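# second cast: the same op on the transposed copy (the dim1 cast)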
x_mx_dim1 = MXTensor.to_mx(
x_hp_t_c,
config.elem_dtype,
config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
return x_mx_dim0, x_mx_dim1

print("m_ref", m_ref)
print("m_lowp", m_lowp)
print("input_tensor.shape", input_tensor.shape)
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
elif mode_filter == "cast_with_to_blocked":
_input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
return
elif mode_filter == "cast_only_dim0_dim1":
_input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
input_tensor,
)
return

if enable_activation_checkpointing:
out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
m_lowp = torch.compile(m_lowp, fullgraph=True)
to_mx_func = torch.compile(to_mx_func, fullgraph=True)
cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)

# if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
# to populate triton kernel bandwidth further down in the script
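
For context, a minimal standalone sketch of the two casts this new mode profiles, assuming torchao's prototype MXTensor API; the import path, shapes, dtype, and block size below are illustrative, not taken from this PR:

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_hp = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# dim0 cast: quantize the tensor in its original row-major layout
x_mx_dim0 = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)

# dim1 cast: transpose + contiguous first, so the same kernel covers
# the other logical dimension
x_mx_dim1 = MXTensor.to_mx(x_hp.t().contiguous(), torch.float8_e4m3fn, 32)

Profiling both casts together is presumably relevant for training, where the same activation tensor is needed in both layouts and the transpose-plus-contiguous copy contributes to the dim1 cast's cost.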