
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor

     @staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):

         self.create_buffers()

-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )

         # Note: is_amax_initialized is not a buffer to avoid data dependent
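# --- Illustrative sketch, not part of the diff ------------------------------
# The __init__ hunk above bundles three ScaledMMConfig values, one per tensor
# role (x, w, dL_dY), into a single LinearMMConfig. Below is a minimal sketch
# of type shapes that would satisfy the calls in this diff; field and member
# names are assumptions inferred from usage here, not the actual definitions
# in float8_experimental.float8_tensor.
import enum
from typing import NamedTuple


class ScaledMMConfig(NamedTuple):
    emulate: bool = False
    use_fast_accum: bool = False
    fp8_output: bool = False
    pad_inner_dim: bool = False


class GemmInputRole(enum.Enum):
    X = "x"
    W = "w"
    DL_DY = "dL_dY"


class LinearMMConfig(NamedTuple):
    # One gemm config per role, positionally matching the constructor call above.
    x: ScaledMMConfig = ScaledMMConfig()
    w: ScaledMMConfig = ScaledMMConfig()
    dL_dY: ScaledMMConfig = ScaledMMConfig()
# -----------------------------------------------------------------------------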
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8

     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y

     def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
-                        new_mod.forward_config,
+                        new_mod.linear_mm_config,
                     )
                 )
             else:
@@ -468,7 +484,7 @@ def from_float(
                         new_mod.fp8_amax_w,
                         new_mod.fp8_amax_history_w,
                         new_mod.fp8_scale_w,
-                        new_mod.forward_config,
+                        new_mod.linear_mm_config,
                         new_mod.is_amax_initialized,
                     )
                 )
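# --- Illustrative sketch, not part of the diff ------------------------------
# Net effect at a call site, assuming only the signatures visible in this diff:
# the old forward_config/backward_config pair is replaced by one bundled
# linear_mm_config plus an explicit gemm_input_role tag on each cast. The
# emulate/pad_inner_dim flags and the go/fp8_scale_dL_dY tensors below are
# placeholders, not a working training setup.
import torch
from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    to_fp8_no_autograd,
)

emulate, pad_inner_dim = True, False  # stand-ins for the module's config
linear_mm_config = LinearMMConfig(
    ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),  # x
    ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),  # w
    ScaledMMConfig(emulate, False, False, pad_inner_dim),  # dL_dY
)

# Before: to_fp8_no_autograd(go, fp8_scale_dL_dY, e5m2_dtype, mm_config=...)
# After: the bundled config plus a role tag, as in the backward() hunk above.
go = torch.randn(16, 32)  # placeholder gradient
fp8_scale_dL_dY = torch.tensor(1.0)  # placeholder scale
go_fp8 = to_fp8_no_autograd(
    go,
    fp8_scale_dL_dY,
    torch.float8_e5m2,  # standing in for the diff's e5m2_dtype
    linear_mm_config=linear_mm_config,
    gemm_input_role=GemmInputRole.DL_DY,
)
# -----------------------------------------------------------------------------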