
Commit 0cf1281

Update README.md for float8
1 parent: 7a35695

File tree: 1 file changed

torchao/float8/README.md: 8 additions, 16 deletions
@@ -34,21 +34,12 @@ m = nn.Sequential(
 x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
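
This hunk drops the module-filtering example from the README snippet, not from the library. For reference, the removed pattern reassembled as runnable Python; this sketch assumes `convert_to_float8_training` still accepts the optional `module_filter_fn` keyword shown in the deleted call, and that `m` is the `nn.Sequential` model from the snippet:

import torch
from torchao.float8 import convert_to_float8_training  # assumed import path

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # don't convert the last module of the nn.Sequential (its fully
    # qualified name is "1")
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible
    # by 16, a shape requirement for float8 matmul kernels
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# `m` is the model defined earlier in the README snippet
convert_to_float8_training(m, module_filter_fn=module_filter_fn)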
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
 convert_to_float8_training(m, config=config)
 
-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
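
Put together, the updated README example reads roughly as follows. This is a minimal sketch: the layer sizes in `nn.Sequential`, the body of the toy training loop, and the `torchao.float8` import are outside the hunks shown here and are filled in as plausible assumptions:

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training  # assumed import path

# assumed layer sizes; the diff only shows `m = nn.Sequential(` in the
# hunk header (dimensions divisible by 16, as float8 kernels require)
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
# and optionally distributed communications in float8
convert_to_float8_training(m)

# enable torch.compile to generate fused kernels for float8 scaling and
# casting, which improves performance
m = torch.compile(m)

# toy training loop (body assumed; it sits outside the diff hunks)
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()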
