From 65959384d0eca8221b409ee0f72146bc6295e4ae Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 24 Apr 2024 18:15:39 -0700 Subject: [PATCH 01/17] torch.compile each TransformerBlock instead of the whole model This way we could temporarily enable 2-D parallel compile, and it make more sense to do transformer block compile in the future with PP anyways. We should figure out the dynamic shape issue though --- test_runner.py | 8 +++++++- torchtitan/parallelisms/parallelize_llama.py | 13 +++++++++++++ train.py | 12 ------------ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/test_runner.py b/test_runner.py index 80d75ad86a..cdd708e49b 100755 --- a/test_runner.py +++ b/test_runner.py @@ -40,7 +40,7 @@ class OverrideDefinitions: integration_tests_flavors["debug_model.toml"] = [ OverrideDefinitions( [ - ["--training.compile"], + ["--training.compile --model.norm_type=rmsnorm"], ], "1D compile", ), @@ -50,6 +50,12 @@ class OverrideDefinitions: ], "Eager mode 2DParallel", ), + OverrideDefinitions( + [ + ["--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm"], + ], + "2DParallel compile", + ), OverrideDefinitions( [ [ diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 644a6163ac..8cc8d3cece 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -227,6 +227,19 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): transformer_block = checkpoint_wrapper( transformer_block, job_config.activation_checkpoint ) + + if job_config.training.compile: + if ( + job_config.activation_checkpoint.mode == "selective" + and job_config.activation_checkpoint.selective_ac_option == "op" + ): + torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( + True + ) + logger.info(f"Compiling TransformerBlock {layer_id} with torch.compile") + # TODO: dynamic shape have some issues so we turn it off for now. 
+ transformer_block = torch.compile(transformer_block, dynamic=False) + # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately reshard_after_forward = layer_id < len(model.layers) - 1 diff --git a/train.py b/train.py index 5dddd14c4b..0f9b2fdd13 100644 --- a/train.py +++ b/train.py @@ -219,18 +219,6 @@ def loss_fn(pred, labels): metric_logger = build_metric_logger(job_config) - # torch.compile model for improved performance - if job_config.training.compile: - if ( - job_config.activation_checkpoint.mode == "selective" - and job_config.activation_checkpoint.selective_ac_option == "op" - ): - torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( - True - ) - logger.info("Compiling model with torch.compile") - model = torch.compile(model) - train_state = TrainState() # train loop From 68157d4145aa8d541217120fbaeb4d5134ac7427 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 24 Apr 2024 18:22:14 -0700 Subject: [PATCH 02/17] lint --- test_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test_runner.py b/test_runner.py index cdd708e49b..b571b22081 100755 --- a/test_runner.py +++ b/test_runner.py @@ -52,7 +52,9 @@ class OverrideDefinitions: ), OverrideDefinitions( [ - ["--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm"], + [ + "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm" + ], ], "2DParallel compile", ), From dc394b2a5638ec7e988585b71fe3552e9a2fce71 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 16:29:56 -0700 Subject: [PATCH 03/17] comment --- torchtitan/parallelisms/parallelize_llama.py | 9 +-------- train.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 8cc8d3cece..34e0cd79cb 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -229,14 +229,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): ) if job_config.training.compile: - if ( - job_config.activation_checkpoint.mode == "selective" - and job_config.activation_checkpoint.selective_ac_option == "op" - ): - torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( - True - ) - logger.info(f"Compiling TransformerBlock {layer_id} with torch.compile") + # turn on per-transformer block compile after AC wrappnig and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. 
transformer_block = torch.compile(transformer_block, dynamic=False) diff --git a/train.py b/train.py index 0f9b2fdd13..3a31d9f56c 100644 --- a/train.py +++ b/train.py @@ -221,6 +221,17 @@ def loss_fn(pred, labels): train_state = TrainState() + if job_config.training.compile: + if ( + job_config.activation_checkpoint.mode == "selective" + and job_config.activation_checkpoint.selective_ac_option == "op" + ): + # some flags for torch.compile enablement + torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( + True + ) + logger.info(f"Compiling each TransformerBlock with torch.compile") + # train loop model.train() From a03877bdeccd4837ccd430706b688f206cb1416d Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 16:30:53 -0700 Subject: [PATCH 04/17] order --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 3a31d9f56c..6669208462 100644 --- a/train.py +++ b/train.py @@ -219,8 +219,6 @@ def loss_fn(pred, labels): metric_logger = build_metric_logger(job_config) - train_state = TrainState() - if job_config.training.compile: if ( job_config.activation_checkpoint.mode == "selective" @@ -232,6 +230,8 @@ def loss_fn(pred, labels): ) logger.info(f"Compiling each TransformerBlock with torch.compile") + train_state = TrainState() + # train loop model.train() From c74961a5614bca7bd0e1628a3e088fde797cf9b3 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 16:37:02 -0700 Subject: [PATCH 05/17] lint --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 6669208462..57d5c64a9e 100644 --- a/train.py +++ b/train.py @@ -228,7 +228,7 @@ def loss_fn(pred, labels): torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( True ) - logger.info(f"Compiling each TransformerBlock with torch.compile") + logger.info("Compiling each TransformerBlock with torch.compile") train_state = TrainState() From 8dc90b0b1355aac5e29f913f920fcc773ef5ce5b Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 21:58:30 -0700 Subject: [PATCH 06/17] group AC + torch.compile together --- README.md | 5 ++-- torchtitan/parallelisms/parallelize_llama.py | 28 +++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 02d50e9f98..076932f789 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Our guiding principles when building `torchtitan`: #### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development. Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly). -Key features available +### Key features available 1. [FSDP2 with per param sharding](docs/fsdp.md) 2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) @@ -34,7 +34,8 @@ Key features available We report our [Performance](docs/performance.md) verified on 64 A100 GPUs -## Coming soon +### Coming soon + 1. Async checkpointing 2. FP8 support 3. 
Context Parallel diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 34e0cd79cb..19c77200f0 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -213,6 +213,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): logger.info("Applied Tensor Parallelism to the model") + # apply AC + torch.compile + ac_mode = job_config.activation_checkpoint.mode + for layer_id, transformer_block in enumerate(model.layers): + if ac_mode in ("full", "selective"): + transformer_block = checkpoint_wrapper( + transformer_block, job_config.activation_checkpoint + ) + logger.info(f"Applied {ac_mode} activation checkpointing to the model") + + if job_config.training.compile: + # turn on per-transformer block compile after AC wrappnig and before FSDP + # TODO: dynamic shape have some issues so we turn it off for now. + transformer_block = torch.compile(transformer_block, dynamic=False) + + # apply DP (FSDP2) if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names @@ -220,19 +235,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.float32 ) - ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} for layer_id, transformer_block in enumerate(model.layers): - if job_config.activation_checkpoint.mode in ("full", "selective"): - transformer_block = checkpoint_wrapper( - transformer_block, job_config.activation_checkpoint - ) - - if job_config.training.compile: - # turn on per-transformer block compile after AC wrappnig and before FSDP - # TODO: dynamic shape have some issues so we turn it off for now. - transformer_block = torch.compile(transformer_block, dynamic=False) - # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately reshard_after_forward = layer_id < len(model.layers) - 1 @@ -243,8 +247,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): ) model.layers[layer_id] = transformer_block model = fully_shard(model, **fsdp_config) - if ac_mode in ("full", "selective"): - logger.info(f"Applied {ac_mode} activation checkpointing to the model") logger.info("Applied FSDP to the model") return model From 90429de674404661e9e74e6eac074317d799b51d Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 22:07:00 -0700 Subject: [PATCH 07/17] refactor and reorganize more --- torchtitan/parallelisms/parallelize_llama.py | 27 +++++++++++--------- train.py | 11 -------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 19c77200f0..8d441c7316 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -149,6 +149,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy( job_config ) + loss_parallel = parallel_dims.loss_parallel_enabled # 1. Parallelize the first embedding and the last linear proj layer # 2. 
Parallelize the root norm layer over the sequence dim @@ -162,12 +163,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): ), "output": col_parallel_strategy( input_layouts=Shard(1), - output_layouts=( - Shard(-1) - if parallel_dims.loss_parallel_enabled - else Replicate() - ), - use_local_output=not parallel_dims.loss_parallel_enabled, + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, ), "norm": SequenceParallel(), "layers.0": PrepareModuleInput( @@ -214,13 +211,19 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): logger.info("Applied Tensor Parallelism to the model") # apply AC + torch.compile - ac_mode = job_config.activation_checkpoint.mode - for layer_id, transformer_block in enumerate(model.layers): - if ac_mode in ("full", "selective"): - transformer_block = checkpoint_wrapper( - transformer_block, job_config.activation_checkpoint + ac_config = job_config.activation_checkpoint + if job_config.training.compile: + if ac_config.mode == "selective" and ac_config.selective_ac_option == "op": + # some temp flags for torch.compile enablement + SAC + torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( + True ) - logger.info(f"Applied {ac_mode} activation checkpointing to the model") + logger.info("Compiling each TransformerBlock with torch.compile") + + for layer_id, transformer_block in enumerate(model.layers): + if ac_config.mode in ("full", "selective"): + transformer_block = checkpoint_wrapper(transformer_block, ac_config) + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") if job_config.training.compile: # turn on per-transformer block compile after AC wrappnig and before FSDP diff --git a/train.py b/train.py index 57d5c64a9e..0f9b2fdd13 100644 --- a/train.py +++ b/train.py @@ -219,17 +219,6 @@ def loss_fn(pred, labels): metric_logger = build_metric_logger(job_config) - if job_config.training.compile: - if ( - job_config.activation_checkpoint.mode == "selective" - and job_config.activation_checkpoint.selective_ac_option == "op" - ): - # some flags for torch.compile enablement - torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( - True - ) - logger.info("Compiling each TransformerBlock with torch.compile") - train_state = TrainState() # train loop From c887859c027d9b92ecf0182916b2c79675058248 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 22:08:59 -0700 Subject: [PATCH 08/17] fix typo --- torchtitan/parallelisms/parallelize_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 8d441c7316..561b833845 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -226,7 +226,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") if job_config.training.compile: - # turn on per-transformer block compile after AC wrappnig and before FSDP + # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. 
transformer_block = torch.compile(transformer_block, dynamic=False) From 640ebbc9b089c4cdbd8343fee9d7006fa6a10ae9 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 22:11:48 -0700 Subject: [PATCH 09/17] lint --- torchtitan/parallelisms/parallelize_llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 561b833845..84363e8f7d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -223,7 +223,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): for layer_id, transformer_block in enumerate(model.layers): if ac_config.mode in ("full", "selective"): transformer_block = checkpoint_wrapper(transformer_block, ac_config) - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + logger.info( + f"Applied {ac_config.mode} activation checkpointing to the model" + ) if job_config.training.compile: # turn on per-transformer block compile after AC wrapping and before FSDP From 2e6f1cfecaa25e11a15000c18ced37ce5a9038e1 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 22:30:34 -0700 Subject: [PATCH 10/17] more fixes --- torchtitan/parallelisms/parallelize_llama.py | 27 ++++++++++---------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 84363e8f7d..705acf883d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -212,25 +212,26 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): # apply AC + torch.compile ac_config = job_config.activation_checkpoint - if job_config.training.compile: - if ac_config.mode == "selective" and ac_config.selective_ac_option == "op": - # some temp flags for torch.compile enablement + SAC - torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( - True - ) - logger.info("Compiling each TransformerBlock with torch.compile") - for layer_id, transformer_block in enumerate(model.layers): if ac_config.mode in ("full", "selective"): transformer_block = checkpoint_wrapper(transformer_block, ac_config) - logger.info( - f"Applied {ac_config.mode} activation checkpointing to the model" - ) - if job_config.training.compile: # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. 
- transformer_block = torch.compile(transformer_block, dynamic=False) + model.layers[layer_id] = torch.compile(transformer_block, dynamic=False) + + if ac_config.mode in ("full", "selective"): + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + if ( + job_config.training.compile + and ac_config.mode == "selective" + and ac_config.selective_ac_option == "op" + ): + # some temp flags for torch.compile enablement + SAC + torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( + True + ) + logger.info("Compiled each TransformerBlock with torch.compile") # apply DP (FSDP2) if parallel_dims.dp_enabled: From 6eb6cf4475da1077f8040c1d532977026e5eb0a5 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 22:39:10 -0700 Subject: [PATCH 11/17] fix logging --- torchtitan/parallelisms/parallelize_llama.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 705acf883d..de0a5a5392 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -212,10 +212,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): # apply AC + torch.compile ac_config = job_config.activation_checkpoint + enable_compile = job_config.training.compile for layer_id, transformer_block in enumerate(model.layers): if ac_config.mode in ("full", "selective"): transformer_block = checkpoint_wrapper(transformer_block, ac_config) - if job_config.training.compile: + if enable_compile: # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. model.layers[layer_id] = torch.compile(transformer_block, dynamic=False) @@ -223,7 +224,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if ac_config.mode in ("full", "selective"): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") if ( - job_config.training.compile + enable_compile and ac_config.mode == "selective" and ac_config.selective_ac_option == "op" ): @@ -231,6 +232,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( True ) + if enable_compile: logger.info("Compiled each TransformerBlock with torch.compile") # apply DP (FSDP2) From 9e1e08df6bbf46669c932cda06ed60cb6ee76b82 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 23:00:28 -0700 Subject: [PATCH 12/17] explicitly throw if use fused_rmsnorm + compile to avoid surprise --- torchtitan/parallelisms/parallelize_llama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index de0a5a5392..8eef0ced51 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -233,6 +233,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): True ) if enable_compile: + if job_config.model.norm_type == "fused_rmsnorm": + raise NotImplementedError( + "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm." 
+ ) logger.info("Compiled each TransformerBlock with torch.compile") # apply DP (FSDP2) From 50457064e02df81e53bbdb8be971c816575d1c37 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 25 Apr 2024 23:33:18 -0700 Subject: [PATCH 13/17] fix bug around AC assign submodule --- torchtitan/parallelisms/parallelize_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 8eef0ced51..ded495b98a 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -219,7 +219,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if enable_compile: # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. - model.layers[layer_id] = torch.compile(transformer_block, dynamic=False) + transformer_block = torch.compile(transformer_block, dynamic=False) + model.layers[layer_id] = transformer_block if ac_config.mode in ("full", "selective"): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") From ce8e4805903227b2eaddd9c35ac8b3a7fe58fee4 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Fri, 26 Apr 2024 09:54:29 -0700 Subject: [PATCH 14/17] temp changes --- train_configs/llama3_8b.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 3b338d1135..ddedb66d1a 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" -norm_type = "fused_rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] +norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] @@ -32,11 +32,11 @@ warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 pipeline_parallel_degree = 1 fp8_linear = "" -compile = false -dataset = "c4" +compile = true +dataset = "c4_mini" [checkpoint] enable_checkpoint = false @@ -46,6 +46,6 @@ interval = 500 model_weights_only = false export_dtype = "float32" -[activation_checkpoint] -mode = 'selective' # ['none', 'selective', 'full'] -selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy +# [activation_checkpoint] +# mode = 'full' # ['none', 'selective', 'full'] +# selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy From 11e02fe8184e2c47b0faf02b14f7f05f9fc1f47e Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Fri, 17 May 2024 16:00:33 -0700 Subject: [PATCH 15/17] try out reuse cache flag --- torchtitan/parallelisms/parallelize_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index b7620811f8..8888ed3d9a 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -215,6 +215,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if enable_compile: # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. 
+ torch._dynamo.config.inline_inbuilt_nn_modules = True transformer_block = torch.compile(transformer_block, dynamic=False) model.layers[layer_id] = transformer_block From bf5c857f3f10092fdca121ee55b2e9878d883fc8 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 21 May 2024 21:27:50 -0700 Subject: [PATCH 16/17] disable inline_nn_modules --- torchtitan/parallelisms/parallelize_llama.py | 4 +++- train_configs/llama3_8b.toml | 12 ++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 02582ff952..5c69ac4eff 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -369,7 +369,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if enable_compile: # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. - torch._dynamo.config.inline_inbuilt_nn_modules = True + # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate + # compile time. + # torch._dynamo.config.inline_inbuilt_nn_modules = True transformer_block = torch.compile(transformer_block, dynamic=False) model.layers[layer_id] = transformer_block diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index d90456a7ea..aaba99a215 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -32,11 +32,11 @@ warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 pipeline_parallel_degree = 1 fp8_linear = "" -compile = true -dataset = "c4_mini" +compile = false +dataset = "c4" [checkpoint] enable_checkpoint = false @@ -47,6 +47,6 @@ model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] -# [activation_checkpoint] -# mode = 'full' # ['none', 'selective', 'full'] -# selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy +[activation_checkpoint] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy From 81dc9e37d5d919f454c44e11b9216be46702c63f Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 21 May 2024 21:29:38 -0700 Subject: [PATCH 17/17] fix lint --- test_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_runner.py b/test_runner.py index e57d78042a..3bd2770f7f 100755 --- a/test_runner.py +++ b/test_runner.py @@ -131,7 +131,7 @@ def build_test_list(args): OverrideDefinitions( [ [ - "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm + "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", f"--job.dump_folder {args.output_dir}/2d_compile/", ], ],
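
Taken together, the series ends with parallelize_llama() applying activation checkpointing, per-block torch.compile, and FSDP2 sharding to each TransformerBlock, and removes the whole-model torch.compile from train.py. What follows is a condensed sketch of that final per-block ordering, not the torchtitan source: the helper name apply_ac_compile_fsdp is invented for illustration, checkpoint_wrapper is passed in to stand for the AC helper already defined in parallelize_llama.py (its signature differs from torch's built-in wrapper), and the config fields mirror the ones used in the diffs above. The real code keeps AC/compile and FSDP in two separate passes over model.layers; they are merged into one loop here for brevity.

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def apply_ac_compile_fsdp(model, dp_mesh, job_config, checkpoint_wrapper):
    """Illustrative only: per-block AC wrapping -> torch.compile -> FSDP."""
    ac_config = job_config.activation_checkpoint
    enable_compile = job_config.training.compile

    if enable_compile:
        if job_config.model.norm_type == "fused_rmsnorm":
            # fused_rmsnorm is not yet compatible with torch.compile (patch 12)
            raise NotImplementedError(
                "fused_rmsnorm not yet compatible with torch.compile. "
                "Please use layernorm or rmsnorm."
            )
        if ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
            # temporary flag so torch.compile works with selective per-op AC
            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
                True
            )

    fsdp_config = {
        "mesh": dp_mesh,
        "mp_policy": MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        ),
    }

    for layer_id, transformer_block in enumerate(model.layers):
        # 1. AC wrapping first, so compile traces the checkpointed block
        if ac_config.mode in ("full", "selective"):
            transformer_block = checkpoint_wrapper(transformer_block, ac_config)
        # 2. per-block compile, after AC wrapping and before FSDP; dynamic
        #    shapes stay disabled per the TODO carried through the series
        if enable_compile:
            transformer_block = torch.compile(transformer_block, dynamic=False)
        # 3. FSDP2 per-block sharding; the last block is not resharded after
        #    forward since FSDP would prefetch it again immediately
        reshard_after_forward = layer_id < len(model.layers) - 1
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
        # reassign so the wrapped/compiled/sharded block replaces the original
        model.layers[layer_id] = transformer_block

    return fully_shard(model, **fsdp_config)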