From 060780d9801eb85c08da836e2b3a27c273562a8c Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Sat, 3 Aug 2024 22:29:14 -0700
Subject: [PATCH] [fix] float8 should be applied on all model_parts

[ghstack-poisoned]
---
 torchtitan/float8.py | 18 ++++++++++++++----
 train.py             |  4 ++--
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 4dc7122b29..043b18326b 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -13,6 +13,8 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 
@@ -103,7 +105,9 @@ def convert_to_float8_training(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )
 
-    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -112,9 +116,13 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
 
         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 
-        precompute_float8_dynamic_scale_for_fsdp(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)
 
-    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+    def sync_float8_amax_and_scale_history(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -136,4 +144,6 @@ def sync_float8_amax_and_scale_history(self, model: nn.Module):
                     sync_float8_amax_and_scale_history
                 )
 
-        self._sync_float8_amax_and_scale_history(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            self._sync_float8_amax_and_scale_history(m)
diff --git a/train.py b/train.py
index 5c62debfde..58d23307cc 100644
--- a/train.py
+++ b/train.py
@@ -307,7 +307,7 @@ def loss_fn(pred, labels):
                 )
 
             # sync float8 amaxes and scales
-            float8_handler.sync_float8_amax_and_scale_history(model)
+            float8_handler.sync_float8_amax_and_scale_history(model_parts)
 
             # optimizer step
             checkpoint.maybe_wait_for_staging()
@@ -316,7 +316,7 @@ def loss_fn(pred, labels):
 
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
            # it issues a single all-reduce for all parameters at once for better performance
-            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
 
             losses_since_last_log.append(loss)
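
Note (not part of the patch itself): a minimal, self-contained sketch of the normalization pattern both handler methods now use, so a caller may pass either a single nn.Module or the list of pipeline-parallel model_parts. The helper name as_module_list is hypothetical; the patch inlines the equivalent expression inside each method.

    from typing import List, Union

    import torch.nn as nn


    def as_module_list(model: Union[nn.Module, List[nn.Module]]) -> List[nn.Module]:
        # Wrap a bare nn.Module in a one-element list; pass a list through unchanged,
        # so downstream code can always iterate over model parts.
        return [model] if isinstance(model, nn.Module) else model


    parts = [nn.Linear(8, 8), nn.Linear(8, 8)]  # stand-in for model_parts
    single = nn.Linear(8, 8)                    # stand-in for a non-PP model
    assert as_module_list(parts) is parts
    assert as_module_list(single) == [single]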