This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 8cb170a

[TBD if for land] bring back torch.autograd.Function
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: c1b8d0c
Pull Request resolved: #316
1 parent c58fb5d commit 8cb170a
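For readers unfamiliar with the pattern the summary refers to, here is a minimal sketch (not part of this commit; the class and function names are illustrative) of a custom torch.autograd.Function marked with torch._dynamo.allow_in_graph so that torch.compile can trace through it instead of graph-breaking on it:

import torch

@torch._dynamo.allow_in_graph
class _PlainMM(torch.autograd.Function):
    # hypothetical example class, not from float8_experimental
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return torch.mm(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # gradients of out = a @ b
        return torch.mm(grad_out, b.t()), torch.mm(a.t(), grad_out)

def _fn(a, b):
    return _PlainMM.apply(a, b)

compiled = torch.compile(_fn)
out = compiled(torch.randn(4, 8, requires_grad=True), torch.randn(8, 16))
out.sum().backward()

The question the commit is probing is how well this combination (autograd.Function, tensor subclasses, and compile) holds up once float8 tensor subclasses flow through the same path.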

2 files changed: +52 −2 lines

float8_experimental/float8_linear.py

Lines changed: 49 additions & 1 deletion
@@ -71,6 +71,54 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     scale.copy_(new_scale)
 
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+@torch._dynamo.allow_in_graph
+class manual_float8_mm(torch.autograd.Function):
+    """
+    Like torch.mm, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8_t,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8_t)
+        orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8_t = ctx.saved_tensors
+
+        go_fp8_orig_shape = go_fp8.shape
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # calculate dL/dX
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t.t(),
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        x_fp8_orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])
+
+        # calculate dL/dW
+        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
+        # compared to calculating `dL_dW_t = x_fp8_t @ go_fp8_reshaped`
+        dL_dW = torch.mm(
+            go_fp8_reshaped.t(),
+            x_fp8_reshaped,
+        )
+
+        return dL_dX, dL_dW.t()
+
+
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
@@ -410,7 +458,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        y = manual_float8_mm.apply(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
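As a side note, a quick way to sanity-check a hand-rolled mm autograd.Function like manual_float8_mm is to compare it against torch.matmul on plain float tensors. This is only an illustrative sketch (the import path assumes the module layout shown above; in the library the inputs are Float8Tensor instances rather than plain tensors):

import torch
from float8_experimental.float8_linear import manual_float8_mm

x = torch.randn(2, 3, 8, requires_grad=True)
w_t = torch.randn(8, 16, requires_grad=True)

# reference path through the built-in matmul
y_ref = torch.matmul(x, w_t)
y_ref.sum().backward()
gx_ref, gw_ref = x.grad.clone(), w_t.grad.clone()

# same computation through the custom autograd.Function
x.grad, w_t.grad = None, None
y = manual_float8_mm.apply(x, w_t)
y.sum().backward()

assert torch.allclose(y, y_ref)
assert torch.allclose(x.grad, gx_ref)
assert torch.allclose(w_t.grad, gw_ref)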

float8_experimental/float8_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -101,7 +101,9 @@ def choose_scaled_mm_config(
             a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
         ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
         return a_linear_mm_config.dL_dX
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
+    elif (a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY) or (
+        a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
+    ):
         assert (
             a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
         ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
