class FbgemmFp8Tensor(TorchAOBaseTensor):
    """
    TODO: needs padding for cutlass kernels
+    Args:
+      data_to_scale_dim: maps a dim of float8_data to the corresponding dim of scale, e.g.
+      float8_data: (batch_size, output_channel, input_channel)
+      scale: (batch_size, output_channel) (since it's per row quantization)
+      data_to_scale_dim: {0: 0, 1: 1}
    """
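The mapping in the docstring can be checked with plain tensor shapes. A minimal sketch (illustrative sizes, ordinary tensors standing in for the fp8 payload; not part of this diff):

import torch

# per-row quantization of a (batch, out, in) weight keeps one scale per (batch, out) row
float8_data = torch.zeros(8, 256, 128)   # (batch_size, output_channel, input_channel)
scale = torch.ones(8, 256)               # (batch_size, output_channel)
data_to_scale_dim = {0: 0, 1: 1}         # float8_data dim -> scale dim

for data_dim, scale_dim in data_to_scale_dim.items():
    assert float8_data.shape[data_dim] == scale.shape[scale_dim]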

    tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
-    tensor_attributes = ["dtype"]
+    tensor_attributes = ["data_to_scale_dim", "dtype"]

-    def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
+    def __new__(cls, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype):
        shape = float8_data.shape
        kwargs = {}
        kwargs["device"] = float8_data.device
        kwargs["dtype"] = dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
+    def __init__(self, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype):
        self.float8_data = float8_data
        self.scale = scale
+        self.data_to_scale_dim = data_to_scale_dim
        self.activation_scale_ub = activation_scale_ub

    def __tensor_flatten__(self):
@@ -68,12 +74,12 @@ def _apply_fn_to_data(self, fn):
    def __repr__(self):
        return (
            f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
-            f"activation_scale_ub={self.activation_scale_ub}, "
+            f"activation_scale_ub={self.activation_scale_ub}, data_to_scale_dim={self.data_to_scale_dim}, "
            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
        )

    def _quantization_type(self):
-        return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
+        return f"shape={self.shape}, data_to_scale_dim={self.data_to_scale_dim}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -82,9 +88,57 @@ def to(self, *args, **kwargs):
            self.float8_data.to(device),
            self.scale.to(device),
            self.activation_scale_ub.to(device),
+            self.data_to_scale_dim,
            self.dtype,
        )

+    def _transpose_and_reshape(self):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D weights.
+        """
+        assert len(self.shape) == 3, f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
+        dim0, dim1, dim2 = self.shape
+        # because we first transpose the weight before quantization, we'll recover the original shape
+        # by swapping dim1 and dim2
+        original_shape = (dim0, dim2, dim1)
+        # we must save this as 2D in the state dict, since the loading code expects 2D weights
+        new_shape = (-1, original_shape[-1])
+        float8_data = self.float8_data
+        float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous()
+        data_to_scale_dim = {0: 0, 1: 1}
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
+            self.dtype
+        )
+
+    def _unflatten(self, num_experts):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D weights.
+        """
+        float8_data = self.float8_data
+        dim0, dim1 = self.shape
+        float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        data_to_scale_dim = {0: 0}
+        dim0, dim1, dim2 = float8_data.shape
+        if dim1 == self.scale.shape[1]:
+            data_to_scale_dim[1] = 1
+        else:
+            data_to_scale_dim[2] = 1
+
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
+            self.dtype
+        )
+
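As a rough shape check of the two helpers above (illustrative sizes, plain tensors standing in for the fp8 payload; not part of this diff): _transpose_and_reshape flattens a 3D (E, N, K) quantized weight into the 2D (E*K, N) layout the loading code expects, and _unflatten recovers a 3D layout.

import torch

E, N, K = 4, 256, 128                       # illustrative expert / output / input sizes
w3d = torch.zeros(E, N, K)                  # stands in for a 3D float8_data payload

flat = w3d.transpose(1, 2).reshape(-1, N)   # what _transpose_and_reshape does to float8_data
assert flat.shape == (E * K, N)             # 2D layout saved in the state dict

restored = flat.unflatten(0, (E, -1)).squeeze(dim=0)   # what _unflatten(num_experts=E) does
assert restored.shape == (E, K, N)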
    @classmethod
    def from_float(
        cls,
@@ -106,14 +160,18 @@ def from_float(
        else:
            w = w.t()

-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        data_to_scale_dim = {0: 0}
+        if w.ndim == 3:
+            data_to_scale_dim[1] = 1
+
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
        dtype = w.dtype
        del w
        return FbgemmFp8Tensor(
            wq,
            w_scale,
            activation_scale_ub=activation_scale_ub,
+            data_to_scale_dim=data_to_scale_dim,
            dtype=dtype,
        )

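The branch above only records which float8_data dims carry a scale dim after the per-row quantization. A small sketch of that bookkeeping on hypothetical 2D (linear) and 3D (bmm) weights; the rowwise_scale_mapping helper is illustrative, not part of the source:

import torch

def rowwise_scale_mapping(w: torch.Tensor) -> dict:
    # mirrors from_float: every dim except the last (input channel) gets a scale dim
    mapping = {0: 0}
    if w.ndim == 3:
        mapping[1] = 1
    return mapping

assert rowwise_scale_mapping(torch.empty(256, 128)) == {0: 0}           # linear weight
assert rowwise_scale_mapping(torch.empty(4, 256, 128)) == {0: 0, 1: 1}  # bmm / MoE weight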
@@ -169,6 +227,8 @@ def _(func, types, args, kwargs):

    a_data = xq
    b_data = weight_tensor.float8_data
+    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+
    orig_out_features = b_data.shape[-2]

    res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
@@ -269,6 +329,65 @@ def _(func, types, args, kwargs):
    )


+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = tensor_0.ndim + dim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub
+        assert tensor_0.data_to_scale_dim == tensors[i].data_to_scale_dim
+
+    float8_data = [t.float8_data for t in tensors]
+    scale = [t.scale for t in tensors]
+
+    # with rowwise quantization, the dimensions of float8_data and the
+    # original shape are the same, so the original dim argument applies
+    # to float8_data
+    cat_float8_data = aten.cat.default(float8_data, dim)
+
+    # if the cat dimension has a corresponding scale dimension, we concat the scales
+    # along that dimension, otherwise we just reuse the existing scale
+    if dim in tensor_0.data_to_scale_dim:
+        cat_scale = aten.cat.default(scale, dim=tensor_0.data_to_scale_dim[dim])
+    else:
+        cat_scale = scale[0]
+
+    new = tensor_0.__class__(
+        cat_float8_data, cat_scale, tensor_0.activation_scale_ub, tensor_0.data_to_scale_dim, tensor_0.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
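As a concrete illustration of the scale handling above (illustrative shapes, ordinary tensors; not part of this diff): for two per-row quantized 3D weights with data_to_scale_dim = {0: 0, 1: 1}, concatenating along dim 0 or dim 1 also concatenates the scales along the mapped dim, while concatenating along dim 2 (the input channel, which has no scale dim) reuses the first tensor's scale.

import torch

scales = [torch.ones(4, 256), torch.ones(4, 256)]   # one scale per (expert, output_channel) row
data_to_scale_dim = {0: 0, 1: 1}

def cat_scale(scales, dim):
    # mirrors the scale branch of the aten.cat handler above
    if dim in data_to_scale_dim:
        return torch.cat(scales, dim=data_to_scale_dim[dim])
    return scales[0]

assert cat_scale(scales, dim=0).shape == (8, 256)    # experts concatenated
assert cat_scale(scales, dim=1).shape == (4, 512)    # output channels concatenated
assert cat_scale(scales, dim=2).shape == (4, 256)    # input-channel cat keeps the scale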
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    float8_data = self.float8_data.transpose(dim0, dim1).contiguous()
+    data_to_scale_dim = self.data_to_scale_dim.copy()
+
+    if dim0 in data_to_scale_dim:
+        data_to_scale_dim[dim1] = data_to_scale_dim[dim0]
+        del data_to_scale_dim[dim0]
+    elif dim1 in data_to_scale_dim:
+        data_to_scale_dim[dim0] = data_to_scale_dim[dim1]
+        del data_to_scale_dim[dim1]
+
+    new = self.__class__(
+        float8_data,
+        self.scale,
+        self.activation_scale_ub,
+        data_to_scale_dim,
+        self.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
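The remapping above only moves the key; the scale tensor itself is untouched. For example (illustrative, not part of this diff), transposing dims 1 and 2 of a weight whose data_to_scale_dim is {0: 0, 1: 1}:

data_to_scale_dim = {0: 0, 1: 1}   # before transpose(1, 2)
dim0, dim1 = 1, 2

remapped = dict(data_to_scale_dim)
if dim0 in remapped:
    remapped[dim1] = remapped.pop(dim0)
elif dim1 in remapped:
    remapped[dim0] = remapped.pop(dim1)

assert remapped == {0: 0, 2: 1}    # the scale dim now corresponds to data dim 2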
to_fbgemm_fp8 = FbgemmFp8Tensor.from_float

