@@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
-        if func is aten.clone.default:
+        elif func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
-        if func is aten.t.default:
+        elif func is aten.t.default:
             """we don't need to repack the weight here; we just rely on the
             external shape change and record the transposed/not-transposed status
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
-
-        raise NotImplementedError(
-            f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
-        )
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type)
+            else:
+                raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        else:
+            raise NotImplementedError(
+                f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
+            )
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
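For context on the new slice branch: `fill_defaults` is a torchao helper that pads the dispatched positional args (`self, dim, start, end, step`) out to a fixed arity, since callers may omit trailing arguments. Note also that the `dim == 1` branch can reuse `scale` unchanged because a 1-D scale is per output row, and slicing columns leaves rows intact. A minimal standalone sketch of the assumed `fill_defaults` semantics:

import torch

def fill_defaults(args, n, defaults_tail):
    # Pad `args` to length `n`; missing trailing positions are drawn
    # from `defaults_tail`, which covers the last len(defaults_tail) slots.
    r = list(args)
    for i in range(len(args), n):
        r.append(defaults_tail[i - n + len(defaults_tail)])
    return r

# x[0:4] dispatches roughly as aten.slice.Tensor(x, 0, 0, 4); `step` is
# omitted, so the branch above recovers step == 1 from the defaults.
x = torch.randn(8, 8)
self_, dim, start, end, step = fill_defaults((x, 0, 0, 4), 5, [0, None, None, 1])
assert (dim, start, end, step) == (0, 0, 4, 1)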
@@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
         use_fast_accum=scaled_mm_config.use_fast_accum,
     ).reshape(out_shape)
 
+def _linear_fp_act_fp8_weight_check(
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    bias: Optional[torch.Tensor],
+) -> bool:
+    return (
+        # input is a native float tensor
+        not is_traceable_wrapper_subclass(input_tensor) and
+        input_tensor.is_floating_point() and
+        # weight is a float8 affine quantized tensor
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor.layout_type, Float8LayoutType) and
+        weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and
+        (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor))
+    )
+
+def _linear_fp_act_fp8_weight_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: AffineQuantizedTensor,
+    bias: Optional[torch.Tensor],
+):
+    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
 
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
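The pair added above is a weight-only fallback: when the activation is a plain float tensor, the fp8 weight is dequantized back to a high-precision dtype and a regular `F.linear` runs, rather than the fused `_scaled_mm` path used when both sides are fp8. A self-contained sketch of that numeric recipe (per-tensor scaling; the names here are illustrative, not torchao internals):

import torch
import torch.nn.functional as F

# Weight-only fp8 quantization, then the dequantize-and-run fallback.
w = torch.randn(16, 32)
scale = w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)   # quantized storage
w_deq = w_fp8.to(torch.float32) * scale       # what dequantize() produces

x = torch.randn(4, 32)                        # float activation, untouched
out = F.linear(x, w_deq)                      # the fallback linear
print(out.shape)                              # torch.Size([4, 16])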
@@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
         (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
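Registration order matters here: the dispatcher walks these (check, impl) pairs in order and runs the first impl whose check returns True, so the dequantize fallback is placed after the fused fp8-activation/fp8-weight entry and only fires when that stricter check fails. A simplified sketch of the first-match pattern (not the torchao implementation itself):

from typing import Callable, List, Tuple

_DISPATCH: List[Tuple[Callable, Callable]] = []

def register(check: Callable, impl: Callable) -> None:
    _DISPATCH.append((check, impl))

def quantized_linear(x, w, b=None):
    # First matching check wins, so fused/specific paths must be
    # registered ahead of generic fallbacks like dequantize + linear.
    for check, impl in _DISPATCH:
        if check(x, w, b):
            return impl(x, w, b)
    raise NotImplementedError("no quantized linear dispatch matched")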