Skip to content

Commit 64d47fd

Browse files
Modifying memory estimation options and minor changes
ghstack-source-id: 5f09824 Pull Request resolved: #435
1 parent 1ec2ece commit 64d47fd

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

estimation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,18 @@ def estimate_memory(job_config: JobConfig):
4949
# fake tensor doesn't work with fused rmsnorm
5050
if (
5151
job_config.model.norm_type == "fused_rmsnorm"
52-
and job_config.estimate.mode == "fake"
52+
and not job_config.memory_estimation.disable_fake_mode
5353
):
5454
logger.info(
5555
"Fused RMSNorm is not supported yet under fake estimation mode. "
5656
"Switching to rmsnorm."
5757
)
5858
job_config.model.norm_type = "rmsnorm"
5959

60+
if job_config.training.compile:
61+
logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
62+
job_config.training.compile = False
63+
6064
parallel_dims = ParallelDims(
6165
dp=job_config.training.data_parallel_degree,
6266
tp=job_config.training.tensor_parallel_degree,
@@ -107,7 +111,7 @@ def loss_fn(pred, labels):
107111
model_config.vocab_size = tokenizer.n_words
108112
model_config.max_seq_len = job_config.training.seq_len
109113

110-
with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext():
114+
with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
111115

112116
logger.info(
113117
f"Building {model_name} {job_config.model.flavor} with {model_config}"
@@ -198,7 +202,7 @@ def loss_fn(pred, labels):
198202
f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
199203
)
200204
print(f"Tracker Max: {tracker_peak / gib} GiB")
201-
if job_config.estimate.mode == "real":
205+
if job_config.memory_estimation.disable_fake_mode and peak_active > 0:
202206
print(f"Tracker Accuracy: {tracker_peak/peak_active}")
203207
gc.enable()
204208

run_llama_train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ if [ $# -ne 0 ]; then
3131
fi
3232

3333
# Check if --memory_estimation.enabled is in the arguments
34-
if echo "$overrides" | grep -q -- "--estimate.memory=True"; then
34+
if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
3535
# Calculate WORLD_SIZE as the product of NGPU and NNODES
3636
# Export WORLD_SIZE and LOCAL_RANK
3737
export WORLD_SIZE=$((NGPU * NNODES))

test_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ def build_test_list():
265265
),
266266
OverrideDefinitions(
267267
[
268-
["--estimate.memory=True", "--estimate.mode=real"],
268+
[
269+
"--memory_estimation.enabled",
270+
]
269271
],
270272
"FSDP2 Memory Tracking and Estimation",
271273
"fsdp2_mem_tracker",

torchtitan/config_manager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,18 +486,18 @@ def __init__(self):
486486
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
487487
)
488488

489-
# estimation mode settings
489+
# memory estimation settings
490490
self.parser.add_argument(
491-
"--estimate.memory",
491+
"--memory_estimation.enabled",
492492
help="Whether to estimate memory usage for FSDP",
493-
default=False,
493+
action="store_true",
494494
)
495495

496496
self.parser.add_argument(
497-
"--estimate.mode",
498-
type=str,
499-
default="fake",
500-
help="Mode of estimation to use ['fake', 'real']",
497+
"--memory_estimation.disable_fake_mode",
498+
help="Whether to estimate memory under FakeTensorMode",
499+
default=False,
500+
action="store_true",
501501
)
502502

503503
def parse_args(self, args_list: list = sys.argv[1:]):

0 commit comments

Comments
 (0)