|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -from dataclasses import dataclass |
8 | | -from functools import cached_property |
9 | 7 |
|
10 | | -from torch.distributed.device_mesh import init_device_mesh |
11 | | -from torchtitan.logging import logger |
12 | | -from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama |
| 8 | +from torchtitan.parallelisms.parallel_dims import ParallelDims |
| 9 | +from torchtitan.parallelisms.parallelize_llama import parallelize_llama |
| 10 | +from torchtitan.parallelisms.pipeline_llama import pipeline_llama |
13 | 11 | from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule |
14 | 12 |
|
15 | 13 |
|
|
28 | 26 | "llama2": pipeline_llama, |
29 | 27 | "llama3": pipeline_llama, |
30 | 28 | } |
31 | | - |
32 | | - |
33 | | -@dataclass |
34 | | -class ParallelDims: |
35 | | - dp: int |
36 | | - tp: int |
37 | | - pp: int |
38 | | - world_size: int |
39 | | - enable_loss_parallel: bool |
40 | | - dp_type: str |
41 | | - |
42 | | - def __post_init__(self): |
43 | | - self.dp_type = self.dp_type.lower() |
44 | | - self._validate() |
45 | | - |
46 | | - def _validate(self): |
47 | | - dp, tp, pp = self.dp, self.tp, self.pp |
48 | | - if dp == -1: |
49 | | - self.dp = dp = self.world_size // (tp * pp) |
50 | | - assert dp >= 1, dp |
51 | | - assert tp >= 1, tp |
52 | | - assert pp >= 1, pp |
53 | | - assert ( |
54 | | - dp * tp * pp == self.world_size |
55 | | - ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" |
56 | | - assert self.dp_type in ("fsdp", "ddp") |
57 | | - |
58 | | - def build_mesh(self, device_type): |
59 | | - dims = [] |
60 | | - names = [] |
61 | | - for d, name in zip( |
62 | | - [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True |
63 | | - ): |
64 | | - if d > 1: |
65 | | - dims.append(d) |
66 | | - names.append(name) |
67 | | - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") |
68 | | - names = tuple(names) |
69 | | - return init_device_mesh(device_type, dims, mesh_dim_names=names) |
70 | | - |
71 | | - @property |
72 | | - def dp_enabled(self): |
73 | | - return self.dp > 1 |
74 | | - |
75 | | - @property |
76 | | - def tp_enabled(self): |
77 | | - return self.tp > 1 |
78 | | - |
79 | | - @property |
80 | | - def pp_enabled(self): |
81 | | - return self.pp > 1 |
82 | | - |
83 | | - @property |
84 | | - def loss_parallel_enabled(self): |
85 | | - return self.tp > 1 and self.enable_loss_parallel |
86 | | - |
87 | | - @cached_property |
88 | | - def model_parallel_size(self): |
89 | | - return self.tp * self.pp |
0 commit comments