Commit a4d88d1

[fix] float8 should be applied on all model_parts
ghstack-source-id: 52ed683
Pull Request resolved: #500
1 parent 8849580 commit a4d88d1

File tree

2 files changed (+16, −6 lines)


torchtitan/float8.py

Lines changed: 14 additions & 4 deletions
@@ -13,6 +13,8 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance

+from typing import List, Union
+
 import torch
 import torch.nn as nn

@@ -103,7 +105,9 @@ def convert_to_float8_training(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )

-    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return

@@ -112,9 +116,13 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):

         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

-        precompute_float8_dynamic_scale_for_fsdp(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)

-    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+    def sync_float8_amax_and_scale_history(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return

@@ -136,4 +144,6 @@ def sync_float8_amax_and_scale_history(self, model: nn.Module):
             sync_float8_amax_and_scale_history
         )

-        self._sync_float8_amax_and_scale_history(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            self._sync_float8_amax_and_scale_history(m)
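
For context, the change above normalizes the argument so each handler method accepts either a single nn.Module or a list of pipeline-split model parts. A minimal standalone sketch of that pattern (the apply_to_model_parts helper below is hypothetical, not part of torchtitan):

from typing import Callable, List, Union

import torch.nn as nn


def apply_to_model_parts(
    fn: Callable[[nn.Module], None],
    model: Union[nn.Module, List[nn.Module]],
) -> None:
    # Normalize to a list so callers can pass either one module
    # (the non-pipeline case) or the list of pipeline stages.
    models = [model] if isinstance(model, nn.Module) else model
    for m in models:
        fn(m)

With this normalization, existing single-module callers keep working unchanged, while the training loop can pass model_parts directly, as shown in the train.py diff below.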

train.py

Lines changed: 2 additions & 2 deletions
@@ -307,7 +307,7 @@ def loss_fn(pred, labels):
         )

         # sync float8 amaxes and scales
-        float8_handler.sync_float8_amax_and_scale_history(model)
+        float8_handler.sync_float8_amax_and_scale_history(model_parts)

         # optimizer step
         checkpoint.maybe_wait_for_staging()
@@ -316,7 +316,7 @@ def loss_fn(pred, labels):

         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # it issues a single all-reduce for all parameters at once for better performance
-        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

         losses_since_last_log.append(loss)

