212 changes: 212 additions & 0 deletions estimation.py
@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gc
import os

import torch
import torch.nn.functional as F
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed import destroy_process_group
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.tensor.parallel import loss_parallel
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import build_optimizers


def estimate_memory(job_config: JobConfig):
init_logger()
logger.info("Estimating memory usage...")
gc.disable()
gc.collect(1)

# Get the world size
world_size = int(os.environ["WORLD_SIZE"])

    # if tp > 1 or pp > 1, we exit
if (
job_config.training.tensor_parallel_degree > 1
or job_config.experimental.pipeline_parallel_degree > 1
):
logger.info(
"Tensor parallelism and pipeline parallelism are not supported yet."
)
return

# fake tensor doesn't work with fused rmsnorm
if (
job_config.model.norm_type == "fused_rmsnorm"
and job_config.estimate.mode == "fake"
):
logger.info(
"Fused RMSNorm is not supported yet under fake estimation mode. "
"Switching to rmsnorm."
)
job_config.model.norm_type = "rmsnorm"

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)

# init fake pg
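    # a FakeStore-backed "fake" process group mimics a world_size-rank job from this
    # single process; no real ranks or collectives are launched during estimation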
store = FakeStore()
torch.distributed.init_process_group(
"fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")

if not parallel_dims.dp_enabled:
logger.info("Data parallelism is not enabled. Skipping memory estimation.")
return

model_name = job_config.model.name

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# loss_parallel enables dispatching to efficient loss operators
loss_parallel_ctx = (
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
)

# loss fn can be shared by pipeline-parallel or non-pp execution
def loss_fn(pred, labels):
return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
    # 3. max_seq_len based on inputs
model_config.norm_type = job_config.model.norm_type
model_config.vocab_size = tokenizer.n_words
model_config.max_seq_len = job_config.training.seq_len

with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext():
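        # under FakeTensorMode tensors carry only metadata (shape, dtype, device), so
        # memory usage can be modeled without allocating real GPU memory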

logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
)
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
model_parts = [
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
for m in model_parts
]

init_device = "cuda"
for model in model_parts:
model.to_empty(device=init_device)

if not active_fake_mode():
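            # fake tensors hold no data, so real weight initialization only happens
            # outside fake estimation mode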
whole_model.init_weights()

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)

for model in model_parts:
model.train()
logger.info(f"Vocab size: {model_config.vocab_size}")
# Create a dummy batch instead of loading from a dataset
batch = (
torch.randint(
0,
model_config.vocab_size,
(job_config.training.batch_size, model_config.max_seq_len),
device="cuda",
),
torch.randint(
0,
model_config.vocab_size,
(job_config.training.batch_size, model_config.max_seq_len),
device="cuda",
),
)
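        # FSDPMemTracker instruments the FSDP-wrapped model and its optimizer to
        # attribute memory (parameters, gradients, activations, optimizer states)
        # to individual modules as the loop below runs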
fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
fsdp_memtracker.track_inputs(batch)

with fsdp_memtracker:
for iter_idx in range(2):
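                # two iterations: optimizer states are created lazily on the first
                # step, so only the second iteration reflects steady-state peaks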
input_ids, labels = batch
# train step
with loss_parallel_ctx():
pred = whole_model(input_ids)
loss = loss_fn(pred, labels)
del pred
loss.backward()

# clip gradients
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# optimizer step
optimizers.step()
lr_schedulers.step()
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
if iter_idx == 0:
fsdp_memtracker.reset_mod_stats() # iter 0 does not have optimizer state
gc.collect(1)

fsdp_memtracker.display_modulewise_snapshots(
depth=3, units="MiB", tabulate=True
)
mem_stats = torch.cuda.memory_stats()
peak_active = mem_stats["active_bytes.all.peak"]
peak_reserved = mem_stats["reserved_bytes.all.peak"]
num_retries = mem_stats["num_alloc_retries"]
dev = torch.device(torch.cuda.current_device())
tracker_peak = fsdp_memtracker.get_tracker_snapshot("peak")[dev]["Total"]
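        # in "real" mode the tracker's modeled peak can be sanity-checked against
        # CUDA's measured peak; in "fake" mode only the tracker numbers are meaningful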
gib = 1024**3
print(
f"peak active: {peak_active / gib} GiB | peak reserved:"
f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
)
print(f"Tracker Max: {tracker_peak / gib} GiB")
if job_config.estimate.mode == "real":
print(f"Tracker Accuracy: {tracker_peak/peak_active}")
gc.enable()


if __name__ == "__main__":
config = JobConfig()
config.parse_args()
try:
estimate_memory(config)
finally:
destroy_process_group()
19 changes: 15 additions & 4 deletions run_llama_train.sh
@@ -10,13 +10,14 @@ set -ex
# libUV is a scalable backend for TCPStore which is used in processGroup
# rendezvous. This is the recommended backend for distributed training.
export USE_LIBUV=1
TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan}

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh

NGPU=${NGPU:-"8"}
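# NNODES: number of nodes; used below to derive WORLD_SIZE for memory estimation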
NNODES=${NNODES:-"1"}

# by default log just rank 0 output,
LOG_RANK=${LOG_RANK:-0}
@@ -29,6 +30,16 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
# Check if --estimate.memory=True is in the arguments
if echo "$overrides" | grep -q -- "--estimate.memory=True"; then
# Calculate WORLD_SIZE as the product of NGPU and NNODES
# Export WORLD_SIZE and LOCAL_RANK
export WORLD_SIZE=$((NGPU * NNODES))
export LOCAL_RANK=0
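  # estimation.py reads WORLD_SIZE and LOCAL_RANK from the environment and builds a
  # fake process group, so a plain python run suffices; no torchrun launcher needed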
python estimation.py --job.config_file ${CONFIG_FILE} $overrides
else
# Call train.py if not in estimation mode
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
fi
14 changes: 14 additions & 0 deletions torchtitan/config_manager.py
@@ -480,6 +480,20 @@ def __init__(self):
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
)

# estimation mode settings
self.parser.add_argument(
"--estimate.memory",
help="Whether to estimate memory usage for FSDP",
default=False,
)

self.parser.add_argument(
"--estimate.mode",
type=str,
default="fake",
help="Mode of estimation to use ['fake', 'real']",
)

def parse_args(self, args_list: list = sys.argv[1:]):
args, cmd_args = self.parse_args_from_command_line(args_list)
config_file = getattr(args, "job.config_file", None)