Skip to content

Commit 65b028b

Browse files
Adding FSDP Memory Tracking and Estimation
ghstack-source-id: c8ed20f Pull Request resolved: #425
1 parent abb9e15 commit 65b028b

File tree

3 files changed

+241
-4
lines changed

3 files changed

+241
-4
lines changed

estimation.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import contextlib
8+
import gc
9+
import os
10+
11+
import torch
12+
import torch.nn.functional as F
13+
from torch._guards import active_fake_mode
14+
from torch._subclasses.fake_tensor import FakeTensorMode
15+
from torch.distributed import destroy_process_group
16+
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
17+
from torch.distributed.tensor.parallel import loss_parallel
18+
from torch.testing._internal.distributed.fake_pg import FakeStore
19+
20+
from torchtitan.config_manager import JobConfig
21+
from torchtitan.datasets import create_tokenizer
22+
from torchtitan.float8_linear import build_fp8_linear
23+
from torchtitan.logging_utils import init_logger, logger
24+
from torchtitan.lr_scheduling import get_lr_schedulers
25+
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
26+
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
27+
from train import build_optimizers
28+
29+
30+
def estimate_memory(job_config: JobConfig):
    """Estimate FSDP training memory by running two tracked training steps.

    Reads ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment, creates a
    "fake" process group, builds the configured model on the meta device
    (inside ``FakeTensorMode`` when ``job_config.estimate.mode == "fake"``),
    runs two forward/backward/optimizer iterations under ``FSDPMemTracker``
    and prints the tracker snapshots together with the CUDA allocator peak
    statistics.

    The function logs a message and returns without estimating when the
    tensor- or pipeline-parallel degree is greater than 1, or when data
    parallelism is not enabled.

    Args:
        job_config: Parsed job configuration. If the norm type is
            ``"fused_rmsnorm"`` and the estimation mode is ``"fake"``,
            ``job_config.model.norm_type`` is overwritten with ``"rmsnorm"``.

    Raises:
        KeyError: If ``WORLD_SIZE`` or ``LOCAL_RANK`` is missing from the
            environment.
    """
    init_logger()
    logger.info("Estimating memory usage...")
    gc.disable()
    try:
        gc.collect(1)
        _estimate_memory_impl(job_config)
    finally:
        # Re-enable garbage collection on every exit path, including the
        # early returns and any exception raised during estimation.
        gc.enable()


def _estimate_memory_impl(job_config: JobConfig):
    """Body of ``estimate_memory``; the caller owns the gc disable/enable pair."""
    # Get the world size
    world_size = int(os.environ["WORLD_SIZE"])

    # if tp > 1 or pp > 1, we exit
    if (
        job_config.training.tensor_parallel_degree > 1
        or job_config.experimental.pipeline_parallel_degree > 1
    ):
        logger.info(
            "Tensor parallelism and pipeline parallelism are not supported yet."
        )
        return

    # fake tensor doesn't work with fused rmsnorm
    if (
        job_config.model.norm_type == "fused_rmsnorm"
        and job_config.estimate.mode == "fake"
    ):
        logger.info(
            "Fused RMSNorm is not supported yet under fake estimation mode. "
            "Switching to rmsnorm."
        )
        job_config.model.norm_type = "rmsnorm"

    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )

    # LOCAL_RANK selects the CUDA device and is also used as the rank of the
    # fake process group below.
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # init fake pg
    store = FakeStore()
    torch.distributed.init_process_group(
        "fake", rank=local_rank, world_size=world_size, store=store
    )

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    if not parallel_dims.dp_enabled:
        logger.info("Data parallelism is not enabled. Skipping memory estimation.")
        return

    model_name = job_config.model.name

    # build tokenizer
    tokenizer_type = model_name_to_tokenizer[model_name]
    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

    # loss_parallel enables dispatching to efficient loss operators
    loss_parallel_ctx = (
        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
    )

    # loss fn can be shared by pipeline-parallel or non-pp execution
    def loss_fn(pred, labels):
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

    # build model (using meta init)
    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][job_config.model.flavor]
    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. vocab size from tokenizer
    # 3. max_seq_len based on inputs
    model_config.norm_type = job_config.model.norm_type
    model_config.vocab_size = tokenizer.n_words
    model_config.max_seq_len = job_config.training.seq_len

    # "fake" mode runs everything below under FakeTensorMode; any other mode
    # uses a no-op context.
    estimation_ctx = (
        FakeTensorMode()
        if job_config.estimate.mode == "fake"
        else contextlib.nullcontext()
    )
    with estimation_ctx:

        logger.info(
            f"Building {model_name} {job_config.model.flavor} with {model_config}"
        )
        with torch.device("meta"):
            whole_model = model_cls.from_model_args(model_config)

        # apply fp8 linear module swap
        if job_config.training.fp8_linear:
            build_fp8_linear(whole_model, job_config)

        # apply PT-D DP/TP parallelisms and activation checkpointing
        model_parts = [whole_model]
        model_parts = [
            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
            for m in model_parts
        ]

        init_device = "cuda"
        for model in model_parts:
            model.to_empty(device=init_device)

        # weights are only initialized when no fake mode is active
        if not active_fake_mode():
            whole_model.init_weights()

        # build optimizer after applying parallelisms to the model
        optimizers = build_optimizers(model_parts, job_config)
        lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)

        for model in model_parts:
            model.train()
        logger.info(f"Vocab size: {model_config.vocab_size}")
        # Create a dummy batch instead of loading from a dataset
        batch_shape = (job_config.training.batch_size, model_config.max_seq_len)
        batch = (
            torch.randint(0, model_config.vocab_size, batch_shape, device="cuda"),
            torch.randint(0, model_config.vocab_size, batch_shape, device="cuda"),
        )
        fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
        fsdp_memtracker.track_inputs(batch)

        with fsdp_memtracker:
            for iter_idx in range(2):
                input_ids, labels = batch
                # train step
                with loss_parallel_ctx():
                    pred = whole_model(input_ids)
                    loss = loss_fn(pred, labels)
                    # drop the reference to pred before backward
                    # (presumably to keep it out of the backward peak - confirm)
                    del pred
                    loss.backward()

                # clip gradients
                for model in model_parts:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), job_config.training.max_norm, foreach=True
                    )
                # optimizer step
                optimizers.step()
                lr_schedulers.step()
                optimizers.zero_grad()
                print(f"Peak Memory at iter: {iter_idx}")
                fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
                if iter_idx == 0:
                    # iter 0 does not have optimizer state
                    fsdp_memtracker.reset_mod_stats()
                gc.collect(1)

        fsdp_memtracker.display_modulewise_snapshots(
            depth=3, units="MiB", tabulate=True
        )
        mem_stats = torch.cuda.memory_stats()
        peak_active = mem_stats["active_bytes.all.peak"]
        peak_reserved = mem_stats["reserved_bytes.all.peak"]
        num_retries = mem_stats["num_alloc_retries"]
        dev = torch.device(torch.cuda.current_device())
        tracker_peak = fsdp_memtracker.get_tracker_snapshot("peak")[dev]["Total"]
        gib = 1024**3
        print(
            f"peak active: {peak_active / gib} GiB | peak reserved:"
            f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
        )
        print(f"Tracker Max: {tracker_peak / gib} GiB")
        # the tracker/allocator ratio is only reported in "real" mode
        if job_config.estimate.mode == "real":
            print(f"Tracker Accuracy: {tracker_peak/peak_active}")
204+
205+
206+
if __name__ == "__main__":
    # Script entry point: parse the job config, run the estimation, and
    # always clean up the process group afterwards.
    config = JobConfig()
    config.parse_args()
    try:
        estimate_memory(config)
    finally:
        # estimate_memory can return or raise before it has called
        # init_process_group; only tear the group down when one exists so the
        # cleanup itself does not fail and hide the original error.
        if torch.distributed.is_initialized():
            destroy_process_group()

run_llama_train.sh

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ set -ex
1010
# libUV is a scalable backend for TCPStore which is used in processGroup
1111
# rendezvous. This is the recommended backend for distributed training.
1212
export USE_LIBUV=1
13-
TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
13+
TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan}
1414

1515
# use envs as local overrides for convenience
1616
# e.g.
1717
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
1818

1919
NGPU=${NGPU:-"8"}
20+
NNODES=${NNODES:-"1"}
2021

2122
# by default log just rank 0 output,
2223
LOG_RANK=${LOG_RANK:-0}
@@ -29,6 +30,16 @@ if [ $# -ne 0 ]; then
2930
overrides="$*"
3031
fi
3132

32-
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
33-
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
34-
train.py --job.config_file ${CONFIG_FILE} $overrides
33+
# Check if --estimate.memory=True is in the arguments.
# printf is used instead of echo so option-like or backslash-containing
# overrides are passed through verbatim, and grep -F matches the flag as a
# fixed string so the "." is not treated as a regex wildcard.
if printf '%s\n' "$overrides" | grep -qF -- "--estimate.memory=True"; then
    # estimation.py reads WORLD_SIZE and LOCAL_RANK from the environment.
    # Calculate WORLD_SIZE as the product of NGPU and NNODES
    # Export WORLD_SIZE and LOCAL_RANK
    export WORLD_SIZE=$((NGPU * NNODES))
    export LOCAL_RANK=0
    # $overrides is intentionally unquoted so it splits into separate arguments
    python estimation.py --job.config_file "${CONFIG_FILE}" $overrides
else
    # Call train.py if not in estimation mode
    torchrun --nproc_per_node="${NGPU}" --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
        --local-ranks-filter "${LOG_RANK}" --role rank --tee 3 \
        train.py --job.config_file "${CONFIG_FILE}" $overrides
fi

torchtitan/config_manager.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,20 @@ def __init__(self):
480480
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
481481
)
482482

483+
# estimation mode settings
484+
self.parser.add_argument(
485+
"--estimate.memory",
486+
help="Whether to estimate memory usage for FSDP",
487+
default=False,
488+
)
489+
490+
self.parser.add_argument(
491+
"--estimate.mode",
492+
type=str,
493+
default="fake",
494+
help="Mode of estimation to use ['fake', 'real']",
495+
)
496+
483497
def parse_args(self, args_list: list = sys.argv[1:]):
484498
args, cmd_args = self.parse_args_from_command_line(args_list)
485499
config_file = getattr(args, "job.config_file", None)

0 commit comments

Comments
 (0)