 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms.utils import check_strided_sharding_enabled


 def parallelize_llama(
@@ -83,6 +84,7 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[
                 job_config.training.mixed_precision_reduce
             ],
+            tp_enabled=parallel_dims.tp_enabled,
             pp_enabled=parallel_dims.pp_enabled,
         )
     else:
@@ -289,6 +291,7 @@ def apply_fsdp(
     dp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
     reduce_dtype: torch.dtype,
+    tp_enabled: bool,
     pp_enabled: bool,
 ):
     """
@@ -297,6 +300,12 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

+    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
+    # that users won't use a nightly build which is older than 20240809 by then.
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
@@ -313,18 +322,6 @@ def apply_fsdp(
         )
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

-    if pp_enabled:
-        # TODO
-        # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
-        # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
-        # without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
-        # removed after strided sharding is landed in DCP.
-        for module in model.modules():
-            assert len(module._load_state_dict_pre_hooks) <= 1
-            module._load_state_dict_pre_hooks.clear()
-            assert len(module._state_dict_pre_hooks) <= 1
-            module._state_dict_pre_hooks.clear()
-
     logger.info("Applied FSDP to the model")

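For context, the new `check_strided_sharding_enabled()` call gates TP-enabled (2D/3D) DCP checkpointing on a PyTorch build that already ships strided sharding. Below is a minimal, hypothetical sketch of what such a version gate could look like, assuming (per the TODO above) that strided sharding is available in stable PyTorch >= 2.5 and in nightlies dated 20240809 or later; the actual helper in `torchtitan.parallelisms.utils` may be implemented differently.

```python
import torch


def _strided_sharding_available() -> bool:
    # Hypothetical check, not the actual torchtitan implementation.
    # Nightly builds report versions like "2.5.0.dev20240809+cu121";
    # stable builds report versions like "2.4.1".
    version = torch.__version__
    if ".dev" in version:
        # Compare the nightly date against the 20240809 cutoff from the TODO above.
        nightly_date = version.split(".dev")[1][:8]
        return nightly_date >= "20240809"
    major, minor = (int(x) for x in version.split(".")[:2])
    return (major, minor) >= (2, 5)


def check_strided_sharding_enabled() -> None:
    # Sketch only: fail fast if the installed build cannot produce strided
    # shards, since 2D/3D DCP resharding is unsafe without them.
    if not _strided_sharding_available():
        raise RuntimeError(
            "2D/3D DCP checkpointing requires FSDP2 strided sharding; "
            "upgrade to PyTorch >= 2.5 or a nightly >= 20240809."
        )
```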