 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
 
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 
@@ -103,7 +105,9 @@ def convert_to_float8_training(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )
 
-    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -112,9 +116,13 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
 
         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 
-        precompute_float8_dynamic_scale_for_fsdp(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)
 
-    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+    def sync_float8_amax_and_scale_history(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -136,4 +144,6 @@ def sync_float8_amax_and_scale_history(self, model: nn.Module):
             sync_float8_amax_and_scale_history
         )
 
-        self._sync_float8_amax_and_scale_history(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            self._sync_float8_amax_and_scale_history(m)
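For orientation, a minimal usage sketch of the updated signatures: after this change, both methods accept either a single `nn.Module` or a list of `nn.Module`s, so a pipeline-parallel caller can pass all of its stage modules in one call. The names `handler` and `model_parts` below are placeholders for this sketch, not identifiers from the commit.

```python
from typing import List

import torch.nn as nn


# `handler` stands in for the object whose methods this diff changes;
# `model_parts` stands in for the per-stage modules of a pipeline-parallel model.
def float8_bookkeeping_step(handler, model_parts: List[nn.Module]) -> None:
    # One call now covers every model part instead of looping at the call site.
    handler.sync_float8_amax_and_scale_history(model_parts)
    handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

    # Passing a single module still works: the methods wrap it in a one-element list.
    handler.sync_float8_amax_and_scale_history(model_parts[0])
```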