Commit 004b314
Throw warning if users are using an old PyTorch version that does not include DTensor strided sharding (#507)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #507

**Summary**

1. Check if users are using a nightly-build PyTorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D parallelism is used; print a warning if not.
2. Remove the temporary re-enablement added in #460.

**Test**

Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8`

GPUs: A100

Output:

- without strided sharding:

```
[rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25%
[rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41%
[rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42%
[rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46%
[rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25%
[rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23%
[rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39%
[rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38%
[rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03%
[rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38%
[rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44%
[rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44%
[rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44%
[rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19%
[rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20%
[rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39%
[rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38%
[rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32%
```

- with strided sharding:

```
[rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03%
[rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16%
[rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44%
[rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37%
[rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40%
[rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26%
[rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20%
[rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36%
[rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33%
[rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03%
[rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38%
[rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44%
[rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42%
[rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42%
[rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18%
[rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15%
[rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39%
[rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39%
[rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33%
```
1 parent fa7fe1e commit 004b314
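
Background on why the check matters: with FSDP2 + TP, tensor parallelism shards a weight first and FSDP then shards each TP piece, so the rank-to-shard mapping is strided relative to a naive reading of the `[Shard(0), Shard(0)]` placements over a `(dp, tp)` mesh. The toy sketch below (illustrative only; the mesh sizes and the dp-then-tp placement order are assumptions, not code from this commit) shows the mismatch DCP could hit when resharding without the strided-sharding metadata from pytorch/pytorch#130760:

```python
import torch

w = torch.arange(8)  # toy 1-D weight, dim 0 of size 8, on a 2 (dp) x 2 (tp) mesh

# Actual FSDP2+TP layout: TP chunks dim 0 first, then FSDP chunks each TP piece.
tp_chunks = w.chunk(2)
actual = {(dp, tp): tp_chunks[tp].chunk(2)[dp] for dp in range(2) for tp in range(2)}

# Naive reading of placements [Shard(0), Shard(0)] over the (dp, tp) mesh:
# chunk over dp first, then over tp.
dp_chunks = w.chunk(2)
naive = {(dp, tp): dp_chunks[dp].chunk(2)[tp] for dp in range(2) for tp in range(2)}

print(actual[(0, 1)].tolist())  # [4, 5]: what rank (dp=0, tp=1) really holds
print(naive[(0, 1)].tolist())   # [2, 3]: what a naive reshard would hand it
```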

File tree

- torchtitan/parallelisms/parallelize_llama.py
- torchtitan/parallelisms/utils.py

2 files changed (+39, -12 lines)


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 9 additions & 12 deletions

```diff
@@ -29,6 +29,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms.utils import check_strided_sharding_enabled


 def parallelize_llama(
@@ -83,6 +84,7 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[
                 job_config.training.mixed_precision_reduce
             ],
+            tp_enabled=parallel_dims.tp_enabled,
             pp_enabled=parallel_dims.pp_enabled,
         )
     else:
@@ -289,6 +291,7 @@ def apply_fsdp(
     dp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
     reduce_dtype: torch.dtype,
+    tp_enabled: bool,
     pp_enabled: bool,
 ):
     """
@@ -297,6 +300,12 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

+    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
+    # that users won't use a nightly build which is older than 20240809 by then.
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
@@ -313,18 +322,6 @@
         )
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

-    if pp_enabled:
-        # TODO
-        # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
-        # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
-        # without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
-        # removed after strided sharding is landed in DCP.
-        for module in model.modules():
-            assert len(module._load_state_dict_pre_hooks) <= 1
-            module._load_state_dict_pre_hooks.clear()
-            assert len(module._state_dict_pre_hooks) <= 1
-            module._state_dict_pre_hooks.clear()
-
     logger.info("Applied FSDP to the model")
```
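For a quick local check of the new guard, here is a minimal usage sketch (assumes torchtitan is importable; `tp_enabled` is a stand-in for `parallel_dims.tp_enabled`; the helper itself is the new file below):

```python
import torch

from torchtitan.parallelisms.utils import check_strided_sharding_enabled

tp_enabled = True  # stand-in for parallel_dims.tp_enabled

if tp_enabled:
    # Does nothing on a new-enough nightly; logs a warning on source builds
    # or on nightlies older than 2.5.0.dev20240809.
    check_strided_sharding_enabled()
print(torch.__version__)
```
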
torchtitan/parallelisms/utils.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from torchtitan.logging import logger
+
+
+def check_strided_sharding_enabled() -> None:
+    # Correct 2D/3D DCP usage requires DTensor's strided sharding in PR
+    # https://github.com/pytorch/pytorch/pull/130760. This function checks if users'
+    # PyTorch nightly-build version is newer than 2024-08-09 to make sure this PR is
+    # included when 2D/3D DCP is used.
+    if "git" in torch.__version__:  # pytorch is built from source
+        # notify users to check if the commit hash is newer than 2024-08-09
+        logger.warning(
+            "detected that the pytorch is built from source. Please make sure the PR "
+            "(https://github.com/pytorch/pytorch/pull/130760) is included in pytorch "
+            "for correct 2D/3D DCP usage."
+        )
+    elif torch.__version__ < "2.5.0.dev20240809":
+        # the nightly build pytorch was built before 2024-08-09
+        logger.warning(
+            f"detected that the pytorch version {torch.__version__} is older than "
+            "2.5.0.dev20240809. Please upgrade a newer version to include the change "
+            "made in https://github.com/pytorch/pytorch/pull/130760 for correct 2D/3D "
+            "DCP usage."
+        )
```
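
A note on the version gate above: it is a plain string comparison, which does order same-format nightly tags (`X.Y.Z.devYYYYMMDD`) by date. A small sketch with illustrative values (not from the commit):

```python
cutoff = "2.5.0.dev20240809"  # nightly expected to include pytorch/pytorch#130760

# Same-format nightly tags compare lexicographically in date order.
print("2.5.0.dev20240801" < cutoff)  # True  -> warning fires
print("2.5.0.dev20240815" < cutoff)  # False -> no warning

# A final release string like "2.5.0" is a prefix of the cutoff, so it also
# compares as "older" -- plausibly one reason the TODO in parallelize_llama.py
# removes this check once PyTorch 2.5 is released.
print("2.5.0" < cutoff)  # True
```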
