From 5c8c982c180fea53221348e4589c873f3d799373 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Fri, 5 Apr 2024 09:39:46 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 train.py | 28 ++++------------------------
 1 file changed, 4 insertions(+), 24 deletions(-)

diff --git a/train.py b/train.py
index 849ae78498..16b80e748f 100644
--- a/train.py
+++ b/train.py
@@ -16,7 +16,6 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
@@ -101,14 +100,6 @@ def build_optimizer(model, job_config: JobConfig):
     return optimizer
 
 
-def build_grad_scaler(model):
-    # TODO: FSDP2 does not support sharded grad scaler yet.
-    # TODO: if enabled, grad scaler's states need saving & loading in checkpointing
-    enable_grad_scaling = False
-    logger.info("Gradient scaling not enabled")
-    return ShardedGradScaler(enabled=enable_grad_scaling)
-
-
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -205,10 +196,6 @@ def main(job_config: JobConfig):
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
 
-    # build grad scaler which is effective only when mixed precision training
-    # is enabled with fp16 param dtype under FSDP
-    scaler = build_grad_scaler(model)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -290,25 +277,18 @@ def main(job_config: JobConfig):
                 else contextlib.nullcontext()
             ):
                 loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+                # backward
+                loss.backward()
 
-            # backward on scaled loss to create scaled gradients
-            scaler.scale(loss).backward()
-
-            # clip gradients (after unscaling gradients of the optimizer's params)
-            scaler.unscale_(optimizer)
+            # clip gradients
             torch.nn.utils.clip_grad_norm_(
                 model.parameters(), job_config.training.max_norm, foreach=True
            )
 
             # optimizer step
-            # If gradients don't contain infs/NaNs, optimizer.step() is then called;
-            # otherwise, optimizer.step() is skipped.
-            scaler.step(optimizer)
+            optimizer.step()
             scheduler.step()
 
-            # updates the scale for next iteration
-            scaler.update()
-
             current_loss = loss.item()
             losses_since_last_log.append(current_loss)