diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 1b4141598f..39773a6220 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -484,6 +484,17 @@ def apply_dp(
         model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
     )
+    if parallel_dims.pp_enabled:
+        # TODO
+        # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since,
+        # without strided sharding, DCP cannot safely support resharding for 2D/3D. However, for PP to work, even
+        # without resharding, we load a seed checkpoint and need to disable the safety mechanism. This hack should be
+        # removed after strided sharding lands in DCP.
+        for module in model.modules():
+            assert len(module._load_state_dict_pre_hooks) <= 1
+            module._load_state_dict_pre_hooks.clear()
+            assert len(module._state_dict_pre_hooks) <= 1
+            module._state_dict_pre_hooks.clear()
 
     logger.info("Applied FSDP to the model")
 
     return model
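
For context, a minimal sketch of how a seed checkpoint might be loaded through DCP once apply_dp() has cleared the pre-hooks above. This is not part of the diff; the helper name, the checkpoint directory, and the surrounding setup are illustrative assumptions.

    # Illustrative sketch only (not part of this change): loading a seed
    # checkpoint with torch.distributed.checkpoint after apply_dp() has
    # cleared the DCP safety hooks. The checkpoint directory is an assumption.
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        set_model_state_dict,
    )


    def load_seed_checkpoint(model, checkpoint_dir: str = "outputs/seed_checkpoint"):
        # Collect the (DTensor-sharded) model state dict produced by fully_shard.
        # With the module pre-hooks cleared, calling state_dict() on the
        # 2D-sharded model no longer trips the check from pytorch/pytorch#129519.
        state_dict = get_model_state_dict(model)
        # Read the seed checkpoint in place into the sharded tensors.
        dcp.load(state_dict, checkpoint_id=checkpoint_dir)
        # Write the loaded values back into the model's parameters and buffers.
        set_model_state_dict(model, state_dict)

Since no resharding happens here (the seed checkpoint is produced and consumed under the same layout), bypassing the 2D/3D safety check is what makes this load path viable until strided sharding support lands in DCP.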