 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,24 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    gradY_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    # TODO fix: the next op in the backward does not see this, it sees grad_output[0]
+    return (gradY_fp8,)
+
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
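
The two hook functions above rely on PyTorch's module hook machinery rather than on a custom autograd.Function: a forward pre-hook may return replacement positional args, and a full backward pre-hook may return a replacement grad_output tuple. A minimal, runnable sketch of that mechanism, where plain dtype casts on a toy nn.Linear stand in for the float8 casts (the names here are illustrative, not part of the patch):

import torch
import torch.nn as nn

def fwd_pre_hook(module, args):
    # args is the tuple of positional inputs to forward(); the returned value
    # replaces it (a non-tuple return is wrapped into a 1-tuple by PyTorch)
    return (args[0].to(module.weight.dtype),)

def bwd_pre_hook(module, grad_output):
    # grad_output is the tuple of gradients w.r.t. the module's outputs;
    # returning a tuple of the same length replaces it before backward runs
    return (grad_output[0] * 0.5,)  # halve the gradient just to show the replacement takes effect

lin = nn.Linear(8, 8, dtype=torch.bfloat16)
lin.register_forward_pre_hook(fwd_pre_hook)
lin.register_full_backward_pre_hook(bwd_pre_hook)

x = torch.randn(2, 8, requires_grad=True)  # fp32 input, cast by the forward pre-hook
lin(x).sum().backward()
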
@@ -48,9 +67,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +85,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
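
In the non-hook path the gradient cast still goes through cast_to_float8e5m2_bw, i.e. through NoopFwToFloat8E5M2Bw: an autograd.Function that is the identity in forward and only transforms the incoming gradient in backward. A stripped-down sketch of that pattern, with a simple rescale standing in for the scale-and-saturate cast to float8_e5m2 that the real class performs:

import torch

class NoopFwTransformBw(torch.autograd.Function):
    # identity in forward; transform dL/dY in backward
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, gradY):
        # the real implementation scales/saturates gradY into float8_e5m2 here
        return gradY * 0.5

x = torch.ones(3, requires_grad=True)
NoopFwTransformBw.apply(x).sum().backward()  # x.grad is now 0.5 * ones(3)
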
@@ -69,6 +98,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
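
For reference, the dynamic cast used by cast_to_float8 computes a fresh per-tensor scale from the tensor being cast and then quantizes with it. A hedged usage sketch built from the calls visible in this file (emulate=True is assumed here so no native float8 matmul support is required):

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 16)
scale = tensor_to_scale(x, torch.float8_e4m3fn)  # scale derived from x's current values
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, emulate=True)
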
@@ -92,4 +122,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_pre_hook)
         return new_mod
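
Putting it together, a hedged end-to-end sketch of the new code path: flip the config flag before conversion so that from_float installs the hooks. The module path float8_experimental.float8_dynamic_linear is assumed from the file being patched.

import torch
import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True  # opt in to the hook-based casts

m = torch.nn.Linear(16, 16)
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True)

y = m_fp8(torch.randn(4, 16))  # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()             # backward pre-hook casts dL/dY to float8_e5m2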