Skip to content

Commit fefa9e0

Browse files
committed
Update
[ghstack-poisoned]
1 parent 1c064e0 commit fefa9e0

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,17 @@ def apply_dp(
484484
model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
485485
)
486486

487+
if parallel_dims.pp_enabled:
488+
# TODO
489+
# This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
490+
# without strided sharding, DCP cannot safely support resharding for 2D/3D. However, for PP to work, even
491+
# without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
492+
# removed after strided sharding is landed in DCP.
493+
for module in model.modules():
494+
assert len(module._load_state_dict_pre_hooks) <= 1
495+
module._load_state_dict_pre_hooks.clear()
496+
assert len(module._state_dict_pre_hooks) <= 1
497+
module._state_dict_pre_hooks.clear()
487498
logger.info("Applied FSDP to the model")
488499
return model
489500

0 commit comments

Comments
 (0)