@@ -1107,12 +1107,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:
+                # TODO: scale replication should be dependent on block size
+                if self.scale.ndim == 1:
+                    print("slice for dim 0, scale is 1")
+                    return return_and_correct_aliasing(
+                        func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                    )
+                else:
+                    print("slice for dim 0, scale != 1")
+                    return return_and_correct_aliasing(
+                        func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
+                    )
+            elif dim == 1:
+                print("slice for dim 1")
                 return return_and_correct_aliasing(
-                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                    func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
                 )
-            elif dim == 1:
-                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
-                return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
             else:
                 raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported")
         else:
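For reference, the invariant the slice dispatch above has to preserve is that slicing commutes with dequantization: `dequantize(slice(q)) == slice(dequantize(q))`. A minimal standalone sketch of that property, assuming a per-row 1-D scale (illustrative shapes, not torchao's internal layout):

```python
import torch

# Hypothetical standalone sketch (not torchao API): per-row quantized data
# with a 1-D scale of shape (rows,), mirroring the scale.ndim == 1 branch.
float8_data = torch.randn(8, 16).to(torch.float8_e4m3fn)
scale = torch.rand(8) + 0.5

def dequant(data, scale):
    # Broadcast one scale per row.
    return data.to(torch.float32) * scale.reshape(-1, 1)

# Slicing dim 0 must slice BOTH data and the per-row scale so they stay
# aligned; dequantize-then-slice and slice-then-dequantize must agree.
assert torch.equal(
    dequant(float8_data[2:6], scale[2:6]),
    dequant(float8_data, scale)[2:6],
)

# Slicing dim 1 leaves a per-row scale untouched (one scale per row still
# covers every remaining column), so the scale can be reused as-is there.
assert torch.equal(
    dequant(float8_data[:, 4:12], scale),
    dequant(float8_data, scale)[:, 4:12],
)
```

Under per-row scales, a dim-0 slice must slice data and scale together (the `_apply_fn_to_data` branch), while a dim-1 slice can reuse the scale unchanged; how a block-wise scale should be replicated or sliced is exactly what the TODO above leaves open.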
@@ -1644,6 +1654,11 @@ def _linear_fp8_act_fp8_weight_impl(
     # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
+
+    print(f"out_shape: {out_shape}")
+    print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
+    print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")
+
     # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
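The debug prints added here trace the shape contract of this path: the fp8 activation is `(M, K)`, the weight arrives stored as `(N, K)` and is transposed to `(K, N)` via the `w_data.T` passed to `preprocess_data`, and the scaled matmul produces `(M, N)`. A rough plain-PyTorch emulation of those shapes (per-tensor scales and all sizes are assumptions for illustration; the real path calls a fused scaled-mm kernel, not an explicit dequantize-then-matmul):

```python
import torch

M, K, N = 4, 32, 8  # assumed sizes, purely for illustration

inpt_data = torch.randn(M, K).to(torch.float8_e4m3fn)  # activation, (M, K)
w_data = torch.randn(N, K).to(torch.float8_e4m3fn)     # weight stored as (N, K)
input_scale = torch.tensor(0.5)    # assumed per-tensor scales
weight_scale = torch.tensor(0.25)

# Emulate the scaled mm by dequantizing explicitly: the weight is transposed
# to (K, N) before the matmul, so the output comes out as (M, N).
out = (inpt_data.to(torch.float32) * input_scale) @ (
    w_data.T.to(torch.float32) * weight_scale
)
assert out.shape == (M, N)  # the out_shape printed above should match this
```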
@@ -1858,12 +1873,17 @@ def _(func, types, args, kwargs):
         end = self.shape[dim]
     shape = list(self.shape)
     shape[dim] = end - start
+    print(f"Shape: {self.shape} -> {shape}")
+    print(f"Block size: {self.block_size} -> {self.block_size}")
+    print(f"end: {end}, start: {start}")
     block_size = self.block_size
     assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
     # with slice, some shape dimension might be smaller than block_size dimension, so
     # we need to make sure there is no overflow
     block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
     new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
+    print(f"slice (Outer tensor shape): {self.shape} -> {new.shape}")
+    print(f"slice (Inner data shape): {self.tensor_impl.float8_data.shape} -> {new.tensor_impl.float8_data.shape}")
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 # this is needed for DTensor.from_local() and for flattening tensor
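The `min(...)` clamp in this hunk matters because a slice can shrink a dimension below the quantization block size, and a block size larger than the tensor itself would break the scale bookkeeping. A small sketch of just that computation (the helper name is hypothetical):

```python
def clamp_block_size(shape, block_size):
    # A sliced dimension may end up smaller than the block size, so clamp
    # each block dimension to the (possibly shrunken) tensor dimension.
    assert len(block_size) == 2, "2-D block_size only, matching the assert above"
    return (min(shape[0], block_size[0]), min(shape[1], block_size[1]))

# E.g. a tensor quantized with per-row blocks (1, 256): slicing its columns
# down to 64 must shrink the block to (1, 64)...
assert clamp_block_size((128, 64), (1, 256)) == (1, 64)
# ...while slicing rows leaves a per-row block unchanged.
assert clamp_block_size((32, 256), (1, 256)) == (1, 256)
```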