24 changes: 8 additions & 16 deletions torchao/float8/README.md
@@ -34,21 +34,12 @@ m = nn.Sequential(
 x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
 convert_to_float8_training(m, config=config)
 
-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
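For reference, below is a minimal end-to-end sketch of the updated basic example from this diff, assembled into a runnable script. The `nn.Sequential` model definition and the toy training loop body are elided in the diff, so the versions here are assumptions, as is the `torchao.float8` import path; the script also assumes a CUDA GPU with float8 support. The deleted lines show that the example previously demonstrated excluding modules via a `module_filter_fn`; this sketch follows the simplified default conversion path.

```python
# Minimal sketch assembling the updated README example. The model definition and
# training loop are assumptions; the diff elides them.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# assumed toy model; the README's actual `nn.Sequential` definition is not shown in the diff
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).to(device="cuda", dtype=torch.bfloat16)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
# and optionally distributed communications in float8
convert_to_float8_training(m)

# enable torch.compile to generate fused kernels for float8 scaling and casting,
# which improves performance
m = torch.compile(m)

# assumed toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```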