From a273a49a1482c1ab7a3c9bf436d7b4cce88e988b Mon Sep 17 00:00:00 2001
From: Sanket Jayant Purandare
Date: Fri, 28 Jun 2024 15:30:41 -0700
Subject: [PATCH 1/2] Modifying memory estimation options and minor changes

[ghstack-poisoned]
---
 estimation.py                | 10 +++++++---
 run_llama_train.sh           |  2 +-
 test_runner.py               |  7 +++++--
 torchtitan/config_manager.py | 13 ++++++-------
 4 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/estimation.py b/estimation.py
index 70bd3e6011..04b66adf16 100644
--- a/estimation.py
+++ b/estimation.py
@@ -49,7 +49,7 @@ def estimate_memory(job_config: JobConfig):
     # fake tensor doesn't work with fused rmsnorm
     if (
         job_config.model.norm_type == "fused_rmsnorm"
-        and job_config.estimate.mode == "fake"
+        and job_config.memory_estimation.fake_mode_only
     ):
         logger.info(
             "Fused RMSNorm is not supported yet under fake estimation mode. "
@@ -57,6 +57,10 @@ def estimate_memory(job_config: JobConfig):
         )
         job_config.model.norm_type = "rmsnorm"
 
+    if job_config.training.compile:
+        logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
+        job_config.training.compile = False
+
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
@@ -107,7 +111,7 @@ def loss_fn(pred, labels):
     model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
 
-    with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext():
+    with FakeTensorMode() if job_config.memory_estimation.fake_mode_only else contextlib.nullcontext():
 
         logger.info(
             f"Building {model_name} {job_config.model.flavor} with {model_config}"
@@ -198,7 +202,7 @@ def loss_fn(pred, labels):
             f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
        )
         print(f"Tracker Max: {tracker_peak / gib} GiB")
-        if job_config.estimate.mode == "real":
+        if not job_config.memory_estimation.fake_mode_only and peak_active > 0:
             print(f"Tracker Accuracy: {tracker_peak/peak_active}")
     gc.enable()
 
diff --git a/run_llama_train.sh b/run_llama_train.sh
index ca2001a8da..cf4943a6eb 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -31,7 +31,7 @@ if [ $# -ne 0 ]; then
 fi
 
 # Check if --estimate.memory=True is in the arguments
-if echo "$overrides" | grep -q -- "--estimate.memory=True"; then
+if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
     # Calculate WORLD_SIZE as the product of NGPU and NNODES
     # Export WORLD_SIZE and LOCAL_RANK
     export WORLD_SIZE=$((NGPU * NNODES))
diff --git a/test_runner.py b/test_runner.py
index 63377edd09..7cfafe2a0b 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -265,11 +265,14 @@ def build_test_list():
         ),
         OverrideDefinitions(
             [
-                ["--estimate.memory=True", "--estimate.mode=real"],
+                [
+                    "--memory_estimation.enabled",
+                    "--memory_estimation.fake_mode_only",
+                ]
             ],
             "FSDP2 Memory Tracking and Estimation",
             "fsdp2_mem_tracker",
-            ngpu=4,
+            ngpu=8,
         ),
     ]
     return integration_tests_flavors
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2ff216e146..d3794653b1 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -480,18 +480,17 @@ def __init__(self):
             help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
         )
 
-        # estimation mode settings
+        # memory estimation settings
         self.parser.add_argument(
-            "--estimate.memory",
+            "--memory_estimation.enabled",
             help="Whether to estimate memory usage for FSDP",
-            default=False,
+            action="store_true",
         )
 
         self.parser.add_argument(
"--estimate.mode", - type=str, - default="fake", - help="Mode of estimation to use ['fake', 'real']", + "--memory_estimation.fake_mode_only", + help="Whether to estimate memory under FakeTensorMode", + action="store_true", ) def parse_args(self, args_list: list = sys.argv[1:]): From 6dc4cb07368fe611e1f8a2c0c73fa6ccee471697 Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Mon, 1 Jul 2024 13:22:02 -0700 Subject: [PATCH 2/2] Update on "Modifying memory estimation options and minor changes" As per suggestions from tianyu-l in #425, the config options are now: `./run_llama_train.sh --memory_estimation.enabled --memory_estimation.fake_mode_only` [ghstack-poisoned] --- estimation.py | 6 +++--- test_runner.py | 3 +-- torchtitan/config_manager.py | 3 ++- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/estimation.py b/estimation.py index 04b66adf16..e82a7b7199 100644 --- a/estimation.py +++ b/estimation.py @@ -49,7 +49,7 @@ def estimate_memory(job_config: JobConfig): # fake tensor doesn't work with fused rmsnorm if ( job_config.model.norm_type == "fused_rmsnorm" - and job_config.memory_estimation.fake_mode_only + and not job_config.memory_estimation.disable_fake_mode ): logger.info( "Fused RMSNorm is not supported yet under fake estimation mode. " @@ -111,7 +111,7 @@ def loss_fn(pred, labels): model_config.vocab_size = tokenizer.n_words model_config.max_seq_len = job_config.training.seq_len - with FakeTensorMode() if job_config.memory_estimation.fake_mode_only else contextlib.nullcontext(): + with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext(): logger.info( f"Building {model_name} {job_config.model.flavor} with {model_config}" @@ -202,7 +202,7 @@ def loss_fn(pred, labels): f" {peak_reserved / gib} GiB | num_retries: {num_retries}" ) print(f"Tracker Max: {tracker_peak / gib} GiB") - if not job_config.memory_estimation.fake_mode_only and peak_active > 0: + if job_config.memory_estimation.disable_fake_mode and peak_active > 0: print(f"Tracker Accuracy: {tracker_peak/peak_active}") gc.enable() diff --git a/test_runner.py b/test_runner.py index 7cfafe2a0b..cba63544aa 100755 --- a/test_runner.py +++ b/test_runner.py @@ -267,12 +267,11 @@ def build_test_list(): [ [ "--memory_estimation.enabled", - "--memory_estimation.fake_mode_only", ] ], "FSDP2 Memory Tracking and Estimation", "fsdp2_mem_tracker", - ngpu=8, + ngpu=4, ), ] return integration_tests_flavors diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d3794653b1..6930bb7c00 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -488,8 +488,9 @@ def __init__(self): ) self.parser.add_argument( - "--memory_estimation.fake_mode_only", + "--memory_estimation.disable_fake_mode", help="Whether to estimate memory under FakeTensorMode", + default=False, action="store_true", )