
Commit 5ea7c97

Author: Andrew Gu

Made some stylistic changes to apply_dp

ghstack-source-id: 37d6c4e
Pull Request resolved: #446

1 parent: 3f717c5

1 file changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 9 additions & 11 deletions
@@ -458,23 +458,21 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-
     for layer_id, transformer_block in model.layers.items():
-        # As an optimization, do not reshard after forward for the last
-        # transformer block since FSDP would prefetch it immediately.
-        # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
-        # per microbatch.
-        reshard_after_forward = (
-            int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
-        )
+        if parallel_dims.pp_enabled:
+            # For PP, do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = False
+        else:
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = int(layer_id) < len(model.layers) - 1
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-        model.layers[layer_id] = transformer_block
-
-    model = fully_shard(
+    fully_shard(
         model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
     )
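To make the new resharding policy concrete, here is a minimal, runnable sketch of the branching logic this commit introduces. The function name reshard_after_forward_policy and its parameters (layer_id, num_layers, pp_enabled) are illustrative stand-ins for int(layer_id), len(model.layers), and parallel_dims.pp_enabled; they are not names from the torchtitan codebase.

# Minimal sketch (illustrative names, not torchtitan code) of the
# reshard_after_forward decision introduced by this commit.

def reshard_after_forward_policy(layer_id: int, num_layers: int, pp_enabled: bool) -> bool:
    """Decide whether a transformer block reshards its parameters after forward."""
    if pp_enabled:
        # With pipeline parallelism, keep parameters unsharded after forward
        # (zero-2 style) so each microbatch does not trigger another all-gather.
        return False
    # Without PP, reshard every block except the last one: FSDP prefetches the
    # last block right away, so resharding it would only add overhead.
    return layer_id < num_layers - 1


if __name__ == "__main__":
    num_layers = 4
    for pp_enabled in (False, True):
        decisions = {
            i: reshard_after_forward_policy(i, num_layers, pp_enabled)
            for i in range(num_layers)
        }
        print(f"pp_enabled={pp_enabled}: {decisions}")
    # pp_enabled=False -> {0: True, 1: True, 2: True, 3: False}
    # pp_enabled=True  -> {0: False, 1: False, 2: False, 3: False}

In the actual apply_dp, this boolean is passed to fully_shard(transformer_block, **fsdp_config, reshard_after_forward=...) for each block, and the root model is then wrapped with reshard_after_forward=not parallel_dims.pp_enabled.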
