Merged
1 change: 1 addition & 0 deletions README.md
@@ -37,6 +37,7 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
2. FP8 support
3. Context Parallel
11 changes: 10 additions & 1 deletion test_runner.py
@@ -122,12 +122,21 @@ def build_test_list(args):
        OverrideDefinitions(
            [
                [
-                   "--training.compile",
+                   "--training.compile --model.norm_type=rmsnorm",
                    f"--job.dump_folder {args.output_dir}/1d_compile/",
                ],
            ],
            "1D compile",
        ),
+       OverrideDefinitions(
+           [
+               [
+                   "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
+                   f"--job.dump_folder {args.output_dir}/2d_compile/",
+               ],
+           ],
+           "2D compile",
+       ),
        OverrideDefinitions(
            [
                [
44 changes: 35 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -318,7 +318,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            ),
            "output": col_parallel_strategy(
                input_layouts=Shard(1),
-               output_layouts=(Shard(-1) if loss_parallel else Replicate()),
+               output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
            "norm": SequenceParallel(),
@@ -360,20 +360,49 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

logger.info("Applied Tensor Parallelism to the model")

# apply AC + torch.compile
ac_config = job_config.activation_checkpoint
enable_compile = job_config.training.compile
for layer_id, transformer_block in model.layers.items():
if ac_config.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(transformer_block, ac_config)
if enable_compile:
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shape have some issues so we turn it off for now.
# TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
# compile time.
# torch._dynamo.config.inline_inbuilt_nn_modules = True
transformer_block = torch.compile(transformer_block, dynamic=False)
Reviewer: Curious how far we are from being able to enable fullgraph=True?

Author (collaborator): I think fullgraph=True should already work once we move to per-TransformerBlock compile; maybe I can add that flag too.

I ended up just setting dynamic=False in the current whole-model compile in #297. In that case we can't use fullgraph=True yet, since FSDP still graph-breaks, but with per-TransformerBlock compile each TransformerBlock should already be captured as a full graph.

+       model.layers[layer_id] = transformer_block

+   if ac_config.mode in ("full", "selective"):
+       logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+   if (
+       enable_compile
+       and ac_config.mode == "selective"
+       and ac_config.selective_ac_option == "op"
+   ):
+       # some temp flags for torch.compile enablement + SAC
+       torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+           True
+       )
+   if enable_compile:
Collaborator: nit: Can this error check be moved up to line 216?
+       if job_config.model.norm_type == "fused_rmsnorm":
+           raise NotImplementedError(
+               "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
+           )
+       logger.info("Compiled each TransformerBlock with torch.compile")

    # apply DP (FSDP2)
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )
-       ac_mode = job_config.activation_checkpoint.mode
        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
        for layer_id, transformer_block in model.layers.items():
-           if job_config.activation_checkpoint.mode in ("full", "selective"):
-               transformer_block = checkpoint_wrapper(
-                   transformer_block, job_config.activation_checkpoint
-               )
            # As an optimization, do not reshard after forward for the last
            # transformer block since FSDP would prefetch it immediately.
            # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
@@ -387,12 +416,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                reshard_after_forward=reshard_after_forward,
            )
            model.layers[layer_id] = transformer_block

        model = fully_shard(
            model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
        )
-       if ac_mode in ("full", "selective"):
-           logger.info(f"Applied {ac_mode} activation checkpointing to the model")
        logger.info("Applied FSDP to the model")

    return model
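To make the fullgraph=True discussion above concrete, here is a minimal, hedged sketch of the per-TransformerBlock compile pattern this file adopts. ToyBlock and ToyModel are hypothetical stand-ins, not torchtitan's actual modules; the only structural assumption borrowed from the diff is a model that exposes its blocks through a .layers ModuleDict:

```python
# Sketch only: per-block torch.compile on a toy model (ToyBlock/ToyModel are illustrative).
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ff(self.norm(x))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        # mirror the diff's model.layers: a ModuleDict keyed by layer id
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(n_layers)})

    def forward(self, x):
        for block in self.layers.values():
            x = block(x)
        return x


model = ToyModel()
for layer_id, block in model.layers.items():
    # compile each block individually; dynamic=False because training shapes are static here
    model.layers[layer_id] = torch.compile(block, dynamic=False)

out = model(torch.randn(2, 16, 64))
print(out.shape)
```

Because each block is compiled on its own (after AC wrapping and before FSDP, as the comment in the diff notes), the FSDP hooks added later stay outside the compiled regions, which is why the author expects each TransformerBlock to be captured as a full graph even without passing fullgraph=True at the top level.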
14 changes: 0 additions & 14 deletions train.py
@@ -245,20 +245,6 @@ def loss_fn(pred, labels):

metric_logger = build_metric_logger(job_config)

-   # torch.compile model for improved performance
-   if job_config.training.compile:
-       if (
-           job_config.activation_checkpoint.mode == "selective"
-           and job_config.activation_checkpoint.selective_ac_option == "op"
-       ):
-           torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
-               True
-           )
-       logger.info("Compiling model with torch.compile")
-       # Dynamic shape have issues with distributed, turn dynamic off as Transformer
-       # training is static_shape TODO: resolve dynamic shape issue and restore defaults
-       model = torch.compile(model, dynamic=False)

train_state = TrainState()

# train loop