diff --git a/benchmarks/float8/profile_lowp_training.py b/benchmarks/float8/profile_lowp_training.py
index dd629e7f95..d4a3079360 100644
--- a/benchmarks/float8/profile_lowp_training.py
+++ b/benchmarks/float8/profile_lowp_training.py
@@ -306,8 +306,9 @@ def main(
             "fwd",
             "cast_only",
             "cast_with_to_blocked",
+            "cast_only_dim0_dim1",
         )
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
 
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"
@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
         scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
         return x_mx._data, scale_blocked
 
+    # cast_only_dim0_dim1 mode: cast x_hp and its contiguous transpose with to_mx
+    def cast_only_dim0_dim1(x_hp):
+        x_hp_t_c = x_hp.t().contiguous()
+        x_mx_dim0 = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        x_mx_dim1 = MXTensor.to_mx(
+            x_hp_t_c,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        return x_mx_dim0, x_mx_dim1
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
         elif mode_filter == "cast_with_to_blocked":
             _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
             return
+        elif mode_filter == "cast_only_dim0_dim1":
+            _input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
+                input_tensor,
+            )
+            return
 
         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
         cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
+        cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script