# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
66
7- from collections .abc import Callable
87from dataclasses import dataclass
98from functools import cached_property
109
@@ -23,21 +22,23 @@ class ParallelDims:
2322 cp : int
2423 tp : int
2524 pp : int
25+ ep : int
2626 world_size : int
2727 enable_loss_parallel : bool
2828
2929 def __post_init__ (self ):
3030 self ._validate ()
3131
3232 def _validate (self ):
33- dp_replicate , dp_shard , cp , tp , pp = (
33+ dp_replicate , dp_shard , cp , tp , pp , ep = (
3434 self .dp_replicate ,
3535 self .dp_shard ,
3636 self .cp ,
3737 self .tp ,
3838 self .pp ,
39+ self .ep ,
3940 )
40- for d in (dp_replicate , cp , tp , pp ):
41+ for d in (dp_replicate , cp , tp , pp , ep ):
4142 assert d >= 1 , "Parallelism degree should be >= 1, except for dp_shard"
4243
4344 assert dp_shard == - 1 or dp_shard >= 1 , " dp_shard must -1 or >=1."
@@ -50,7 +51,84 @@ def _validate(self):
5051 f"cp({ cp } ) * tp({ tp } ) * pp({ pp } ) != WORLD_SIZE({ self .world_size } )"
5152 )
5253
54+ if ep > 1 :
55+ # EP would borrow all cp and some dp_shard degree
56+ assert ep % cp == 0 and (dp_shard * cp ) % ep == 0
57+
5358 def build_mesh (self , device_type : str ) -> DeviceMesh :
59+ # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
60+ # is not very clean, due to the limited support from DeviceMesh
61+ # for creating two staggered meshes. Will improve.
62+ if self .ep > 1 :
63+ return self ._build_mesh_with_ep (device_type )
64+ else :
65+ return self ._build_mesh_without_ep (device_type )
66+
67+ def _build_mesh_with_ep (self , device_type : str ) -> DeviceMesh :
68+ # With ep, dp_shard and ep are derived submeshes:
69+ # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
70+ # ep = dp_shard_in_ep * cp
71+ dp_shard_mod_ep = self .dp_shard * self .cp // self .ep
72+ dp_shard_in_ep = self .ep // self .cp
73+
74+ dims = []
75+ names = []
76+ for d , name in zip (
77+ [
78+ self .pp ,
79+ self .dp_replicate ,
80+ dp_shard_mod_ep ,
81+ dp_shard_in_ep ,
82+ self .cp ,
83+ self .tp ,
84+ ],
85+ ["pp" , "dp_replicate" , "dp_shard_mod_ep" , "dp_shard_in_ep" , "cp" , "tp" ],
86+ ):
87+ # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
88+ # helps the MoE layers do mixed precision training
89+ if d > 1 or name == "dp_shard_mod_ep" :
90+ dims .append (d )
91+ names .append (name )
92+
93+ logger .info (f"Building { len (dims )} -D device mesh with { names } , { dims } " )
94+ mesh = init_device_mesh (device_type , dims , mesh_dim_names = names )
95+
96+ # Create all the submesh here to ensure all required process groups are
97+ # initialized:
98+ # Mesh for data loading (no communication on this mesh)
99+ dp_mesh_dim_names = []
100+ # Mesh for param sharding
101+ dp_shard_cp_mesh_dim_names = []
102+ # Mesh for loss all-reduce
103+ dp_cp_mesh_dim_names = []
104+ # Mesh for ep
105+ ep_mesh_dim_names = []
106+
107+ if self .dp_replicate_enabled :
108+ dp_mesh_dim_names .append ("dp_replicate" )
109+ dp_cp_mesh_dim_names .append ("dp_replicate" )
110+ # dp_shard_mod_ep is always needed, even if it's 1
111+ dp_mesh_dim_names .append ("dp_shard_mod_ep" )
112+ dp_shard_cp_mesh_dim_names .append ("dp_shard_mod_ep" )
113+ dp_cp_mesh_dim_names .append ("dp_shard_mod_ep" )
114+ if "dp_shard_in_ep" in names :
115+ dp_mesh_dim_names .append ("dp_shard_in_ep" )
116+ dp_shard_cp_mesh_dim_names .append ("dp_shard_in_ep" )
117+ dp_cp_mesh_dim_names .append ("dp_shard_in_ep" )
118+ ep_mesh_dim_names .append ("dp_shard_in_ep" )
119+ if self .cp_enabled :
120+ dp_shard_cp_mesh_dim_names .append ("cp" )
121+ dp_cp_mesh_dim_names .append ("cp" )
122+ ep_mesh_dim_names .append ("cp" )
123+
124+ mesh [tuple (dp_mesh_dim_names )]._flatten (mesh_dim_name = "dp" )
125+ mesh [tuple (dp_shard_cp_mesh_dim_names )]._flatten (mesh_dim_name = "dp_shard_cp" )
126+ mesh [tuple (dp_cp_mesh_dim_names )]._flatten (mesh_dim_name = "dp_cp" )
127+ mesh [tuple (ep_mesh_dim_names )]._flatten (mesh_dim_name = "ep" )
128+
129+ return mesh
130+
131+ def _build_mesh_without_ep (self , device_type : str ) -> DeviceMesh :
54132 dims = []
55133 names = []
56134 for d , name in zip (
@@ -61,17 +139,8 @@ def build_mesh(self, device_type: str) -> DeviceMesh:
61139 dims .append (d )
62140 names .append (name )
63141
64- return self ._build_mesh (device_type , dims , names , init_device_mesh )
65-
66- def _build_mesh (
67- self ,
68- device_type : str ,
69- dims : list [int ],
70- names : list [str ],
71- init_device_mesh_fn : Callable ,
72- ) -> DeviceMesh :
73142 logger .info (f"Building { len (dims )} -D device mesh with { names } , { dims } " )
74- mesh = init_device_mesh_fn (device_type , dims , mesh_dim_names = names )
143+ mesh = init_device_mesh (device_type , dims , mesh_dim_names = names )
75144
76145 # Create all the submesh here to ensure all required process groups are
77146 # initialized:
@@ -143,3 +212,12 @@ def loss_parallel_enabled(self):
143212 @cached_property
144213 def non_data_parallel_size (self ):
145214 return self .cp * self .tp * self .pp
215+
216+ @property
217+ def ep_enabled (self ):
218+ return self .ep > 1
219+
220+ @property
221+ def dense_params_mesh_ndim (self ):
222+ # Note: EP params mesh ndim is 1 more due to the 'ep' mesh
223+ return self .dp_replicate_enabled + self .fsdp_enabled + self .tp_enabled
0 commit comments