train.py: 28 changes (4 additions & 24 deletions)
@@ -16,7 +16,6 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
@@ -101,14 +100,6 @@ def build_optimizer(model, job_config: JobConfig):
     return optimizer
 
 
-def build_grad_scaler(model):
-    # TODO: FSDP2 does not support sharded grad scaler yet.
-    # TODO: if enabled, grad scaler's states need saving & loading in checkpointing
-    enable_grad_scaling = False
-    logger.info("Gradient scaling not enabled")
-    return ShardedGradScaler(enabled=enable_grad_scaling)
-
-
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -205,10 +196,6 @@ def main(job_config: JobConfig):
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
 
-    # build grad scaler which is effective only when mixed precision training
-    # is enabled with fp16 param dtype under FSDP
-    scaler = build_grad_scaler(model)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -290,25 +277,18 @@ def main(job_config: JobConfig):
                     else contextlib.nullcontext()
                 ):
                     loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+                    # backward
+                    loss.backward()
 
-                    # backward on scaled loss to create scaled gradients
-                    scaler.scale(loss).backward()
-
-                    # clip gradients (after unscaling gradients of the optimizer's params)
-                    scaler.unscale_(optimizer)
+                    # clip gradients
                     torch.nn.utils.clip_grad_norm_(
                         model.parameters(), job_config.training.max_norm, foreach=True
                     )
 
                     # optimizer step
-                    # If gradients don't contain infs/NaNs, optimizer.step() is then called;
-                    # otherwise, optimizer.step() is skipped.
-                    scaler.step(optimizer)
+                    optimizer.step()
                     scheduler.step()
 
-                    # updates the scale for next iteration
-                    scaler.update()
-
                     current_loss = loss.item()
                     losses_since_last_log.append(current_loss)

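Note: the code deleted here followed PyTorch's standard gradient-scaling recipe for fp16 mixed precision, which FSDP2 does not yet support through ShardedGradScaler. For reference, a minimal sketch of that recipe, shown with the single-device torch.cuda.amp.GradScaler for illustration; the model, optimizer, and loss below are placeholders, not part of this PR:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(1024, 1024).cuda()                  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # placeholder optimizer
scaler = GradScaler()  # tracks the loss-scale factor across iterations

for _ in range(10):
    inputs = torch.randn(8, 1024, device="cuda")            # placeholder batch
    optimizer.zero_grad()
    with autocast(dtype=torch.float16):
        loss = model(inputs).square().mean()                # placeholder loss
    # backward on the scaled loss to create scaled gradients
    scaler.scale(loss).backward()
    # unscale before clipping so max_norm applies to the true gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # optimizer.step() runs only if the unscaled grads are free of infs/NaNs
    scaler.step(optimizer)
    # adjust the scale factor for the next iteration
    scaler.update()

Unscaling before clip_grad_norm_ matters because the clipping threshold is expressed in unscaled gradient units; clipping scaled gradients would make the effective threshold depend on the current loss scale.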
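The second TODO in the removed build_grad_scaler notes that a re-enabled scaler would also need its state saved and restored during checkpointing. GradScaler exposes state_dict() and load_state_dict() for this; a sketch, assuming the model and optimizer objects from the snippet above and a hypothetical checkpoint path:

import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler()

# saving: the scaler's scale factor and growth tracker go into the checkpoint
state = {
    "model": model.state_dict(),          # model/optimizer as in the sketch above
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
torch.save(state, "checkpoint.pt")        # "checkpoint.pt" is a hypothetical path

# loading: restore the scaler so the loss scale does not reset between runs
state = torch.load("checkpoint.pt")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
scaler.load_state_dict(state["scaler"])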