torch.compile each TransformerBlock instead of the whole model #268
Merged
Commits (20)
6595938  torch.compile each TransformerBlock instead of the whole model (wanchaol)
68157d4  lint (wanchaol)
dc394b2  comment (wanchaol)
a03877b  order (wanchaol)
c74961a  lint (wanchaol)
8dc90b0  group AC + torch.compile together (wanchaol)
90429de  refactor and reorganize more (wanchaol)
c887859  fix typo (wanchaol)
640ebbc  lint (wanchaol)
2e6f1cf  more fixes (wanchaol)
6eb6cf4  fix logging (wanchaol)
9e1e08d  explicitly throw if use fused_rmsnorm + compile to avoid surprise (wanchaol)
5045706  fix bug around AC assign submodule (wanchaol)
ce8e480  temp changes (wanchaol)
50ceae8  Merge branch 'main' into compile_2d (wanchaol)
6460a48  Merge branch 'main' into compile_2d (wanchaol)
11e02fe  try out reuse cache flag (wanchaol)
962a1e1  Merge branch 'main' into compile_2d (wanchaol)
bf5c857  disable inline_nn_modules (wanchaol)
81dc9e3  fix lint (wanchaol)
@@ -318,7 +318,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(1),
-                    output_layouts=(Shard(-1) if loss_parallel else Replicate()),
+                    output_layouts=Shard(-1) if loss_parallel else Replicate(),
                     use_local_output=not loss_parallel,
                 ),
                 "norm": SequenceParallel(),
@@ -360,20 +360,49 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

     logger.info("Applied Tensor Parallelism to the model")

+    # apply AC + torch.compile
+    ac_config = job_config.activation_checkpoint
+    enable_compile = job_config.training.compile
+    for layer_id, transformer_block in model.layers.items():
+        if ac_config.mode in ("full", "selective"):
+            transformer_block = checkpoint_wrapper(transformer_block, ac_config)
+        if enable_compile:
+            # turn on per-TransformerBlock compile after AC wrapping and before FSDP
+            # TODO: dynamic shapes have some issues, so we turn them off for now.
+            # TODO: inline inbuilt nn modules does not work yet; enable it to accelerate
+            # compile time.
+            # torch._dynamo.config.inline_inbuilt_nn_modules = True
+            transformer_block = torch.compile(transformer_block, dynamic=False)
+        model.layers[layer_id] = transformer_block
+
+    if ac_config.mode in ("full", "selective"):
+        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+        if (
+            enable_compile
+            and ac_config.mode == "selective"
+            and ac_config.selective_ac_option == "op"
+        ):
+            # some temp flags for torch.compile enablement + SAC
+            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+                True
+            )
+    if enable_compile:
Collaborator
nit: Can this error check be moved up to line 216?
+        if job_config.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
+            )
+        logger.info("Compiled each TransformerBlock with torch.compile")

     # apply DP (FSDP2)
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         mp_policy = MixedPrecisionPolicy(
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )
-        ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
         for layer_id, transformer_block in model.layers.items():
-            if job_config.activation_checkpoint.mode in ("full", "selective"):
-                transformer_block = checkpoint_wrapper(
-                    transformer_block, job_config.activation_checkpoint
-                )
             # As an optimization, do not reshard after forward for the last
             # transformer block since FSDP would prefetch it immediately.
             # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings

@@ -387,12 +416,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 reshard_after_forward=reshard_after_forward,
             )
             model.layers[layer_id] = transformer_block
         model = fully_shard(
             model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
         )
-        if ac_mode in ("full", "selective"):
-            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")

     return model
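To see the pattern from the diff in isolation, here is a minimal, self-contained sketch of compiling each block and assigning it back into the parent container before any FSDP wrapping. ToyBlock and ToyModel are hypothetical stand-ins, not torchtitan code; only the loop at the bottom mirrors what the PR does (activation checkpointing would wrap each block first).

# Hedged sketch: per-block torch.compile on a toy model (hypothetical names).
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.mlp(torch.relu(self.attn(x)))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        # ModuleDict keyed by layer id, mirroring model.layers in the diff
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(n_layers)})

    def forward(self, x):
        for block in self.layers.values():
            x = block(x)
        return x


model = ToyModel()
for layer_id, block in model.layers.items():
    # torch.compile returns an nn.Module wrapper, so it can be assigned back
    # into the ModuleDict; dynamic=False mirrors the diff's choice to avoid
    # dynamic-shape issues for now.
    model.layers[layer_id] = torch.compile(block, dynamic=False)

out = model(torch.randn(2, 16, 64))  # each block is traced and compiled separately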
Curious how far we are from being able to enable fullgraph=True?

I think fullgraph=True should work already once we move to per-TransformerBlock compile; maybe I can add that flag too.

I ended up just turning dynamic=False in the current full-model compile (#297). In that case we can't use fullgraph=True yet, since FSDP still graph-breaks, but each TransformerBlock should already be captured as a full graph there.
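For reference, a hedged sketch of what adding that flag to the per-block loop could look like; this is not part of this PR, just the standard torch.compile keyword. fullgraph=True makes Dynamo raise on any graph break inside a block instead of silently splitting the graph.

# Hypothetical variant of the per-block loop above (not in this PR):
# request a single graph per TransformerBlock and fail loudly on graph breaks.
for layer_id, transformer_block in model.layers.items():
    model.layers[layer_id] = torch.compile(
        transformer_block,
        dynamic=False,   # same choice as this PR: sidestep dynamic-shape issues
        fullgraph=True,  # error out instead of graph-breaking within the block
    )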