@@ -34,21 +34,12 @@ m = nn.Sequential(
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
m = torch.compile(m)

# toy training loop
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
convert_to_float8_training(m, config=config)

-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
m = torch.compile(m)

# toy training loop
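For context, the `# toy training loop` context line that closes both hunks refers to a standard forward/backward/optimizer-step loop over the converted, compiled module. Below is a minimal sketch of such a loop, assuming the `m`, `x`, and `optimizer` objects defined earlier in the README snippet; it is illustrative only and not part of this diff.

# hypothetical toy training loop: forward, backward, and optimizer step
# over the float8-converted, torch.compile'd module (illustrative sketch)
for _ in range(10):
    optimizer.zero_grad()
    out = m(x)
    out.sum().backward()
    optimizer.step()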