@@ -306,8 +306,9 @@ def main(
             "fwd",
             "cast_only",
             "cast_with_to_blocked",
+            "cast_only_dim0_dim1",
         )
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"
@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
         scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
         return x_mx._data, scale_blocked
 
+    # casts x_hp to MX along both dim0 (as-is) and dim1 (via a contiguous
+    # transpose); used by the cast_only_dim0_dim1 mode
+    def cast_only_dim0_dim1(x_hp):
+        x_hp_t_c = x_hp.t().contiguous()
+        x_mx_dim0 = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        x_mx_dim1 = MXTensor.to_mx(
+            x_hp_t_c,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        return x_mx_dim0, x_mx_dim1
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
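Why the second cast goes through `x_hp.t().contiguous()`: MX scaling groups `block_size` consecutive elements along a tensor's last dimension, so materializing a contiguous transpose lets the same cast produce scales along the other logical dimension. A minimal sketch of that blocking pattern, assuming only that blocks run along the last dim (`per_block_amax` is a hypothetical stand-in for illustration, not the `MXTensor.to_mx` implementation):

    import torch

    def per_block_amax(x_2d: torch.Tensor, block_size: int = 32) -> torch.Tensor:
        # each scaling block is block_size consecutive elements along the last dim
        m, k = x_2d.shape
        blocks = x_2d.reshape(m, k // block_size, block_size)
        return blocks.abs().amax(dim=-1)  # one scale per block, shape (m, k // block_size)

    x = torch.randn(128, 256)
    scales_row = per_block_amax(x)                   # blocks over each row of x
    scales_col = per_block_amax(x.t().contiguous())  # blocks over each column of x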
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
         elif mode_filter == "cast_with_to_blocked":
             _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
             return
+        elif mode_filter == "cast_only_dim0_dim1":
+            _input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
+                input_tensor,
+            )
+            return
 
         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
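In the new branch the wrapper returns right after the two casts, so profiling `lowp_forw_backward_wrapper` in this mode observes only the dim0/dim1 cast kernels (plus the transpose copy). For illustration, a hedged sketch of timing such an early-return path in isolation with CUDA events; this is an assumption about how one might measure it, not the profiling machinery this script actually uses:

    import torch

    def measure_ms(fn, arg, iters: int = 100) -> float:
        # requires a CUDA device; warm up first so lazy init and
        # any compilation are excluded from the measurement
        for _ in range(3):
            fn(arg)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(arg)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # average milliseconds per call

Usage would then be something like `measure_ms(cast_only_dim0_dim1, input_tensor)`.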
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
         cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
+        cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
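`fullgraph=True` makes `torch.compile` raise on any graph break instead of silently splitting the function into multiple graphs, so the new cast is benchmarked as a single Inductor graph, consistent with the other compiled functions above. A standalone sketch of the pattern with a toy function (not the real cast):

    import torch

    def toy_cast(x: torch.Tensor) -> torch.Tensor:
        # stand-in for cast_only_dim0_dim1; a graph break here would now raise
        return x.abs().clamp(max=1.0).to(torch.bfloat16)

    toy_cast = torch.compile(toy_cast, fullgraph=True)
    out = toy_cast(torch.randn(64, 64))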