
Commit f929df7

fix 2D parallel crash caused by all-reduce on 2D world_mesh
ghstack-source-id: 1c5bf79
Pull Request resolved: #105
1 parent ce048cd commit f929df7
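The crash came from calling the loss-reduction helpers with the full world_mesh: under 2D parallelism that mesh has two dimensions, while the reductions only need the data-parallel group. The fix slices the named 1-D submeshes out of the world mesh and reduces over those. Below is a minimal sketch of how such a 2-D mesh and its named submeshes relate; the ("dp", "sp") names match the diff, but the 2 x 4 layout and the script itself are illustrative assumptions, not code from this repo.

# Illustrative sketch only (not repo code): a 2-D world mesh and its
# named 1-D submeshes. Assumes 8 ranks launched via torchrun, arranged
# as 2 data-parallel groups x 4 sequence-parallel ranks.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "sp"))

# Indexing by dimension name returns a 1-D submesh covering only the
# ranks in this rank's group along that dimension.
dp_mesh = world_mesh["dp"]
sp_mesh = world_mesh["sp"]

assert world_mesh.ndim == 2 and dp_mesh.ndim == 1
print(dp_mesh.size())            # 2: number of data-parallel replicas
print(dp_mesh.get_local_rank())  # this rank's index within its dp group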


2 files changed: +6 -6 lines changed


torchtrain/parallelisms/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
    if parallel_dims.sp_enabled:
        # First we apply Sequence Parallelism if it's enabled
        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
-       sp_degree = job_config.training.sequence_parallelism_degree
+       sp_degree = job_config.training.sequence_parallel_degree
        # First:
        # 1. parallelize the first embedding and the last linear proj layer
        # 2. shard the first layer of transformer block
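The tp_mesh slice above is what the tensor/sequence-parallel plan is applied over. As a hedged illustration of where a 1-D submesh like world_mesh["sp"] ends up, here is a toy module sharded with PyTorch's parallelize_module; the real sharding plan for the Llama model lives in parallelize_llama, and the module, plan, and 2 x 4 layout below are assumptions for the example only.

# Illustrative sketch only (not repo code): applying a tensor-parallel plan
# over the 1-D "sp" submesh, mirroring the tp_mesh slice in the diff above.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyMLP(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.w2(self.w1(x).relu())

world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "sp"))
tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh

# The plan maps submodule names to parallel styles; it is applied over the
# 1-D submesh, not the full 2-D world mesh.
sharded = parallelize_module(
    ToyMLP().cuda(),
    tp_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)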

train.py

Lines changed: 5 additions & 5 deletions
@@ -102,9 +102,9 @@ def main(job_config: JobConfig):

    # build dataloader
    # need dp world size and rank
-   # TODO: dp might not always be 0 so we need to handle that more carefully
-   dp_degree = world_mesh.size(0)
-   dp_rank = world_mesh.get_local_rank(0)
+   dp_mesh = world_mesh["dp"]
+   dp_degree = dp_mesh.size()
+   dp_rank = dp_mesh.get_local_rank()
    build_dataloader_fn = dataloader_fn[job_config.training.dataset]
    data_loader = build_dataloader_fn(
        tokenizer,

@@ -253,8 +253,8 @@ def main(job_config: JobConfig):
                np.max(losses_since_last_log),
            )
            global_avg_loss, global_max_loss = (
-               dist_mean(avg_loss, world_mesh),
-               dist_max(max_loss, world_mesh),
+               dist_mean(avg_loss, dp_mesh),
+               dist_max(max_loss, dp_mesh),
            )

            time_delta = timer() - time_last_log
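dist_mean and dist_max are helper functions defined elsewhere in this repo; the key point of the change is that they now receive the 1-D dp_mesh, so the all-reduce spans only the data-parallel process group. Below is a rough sketch of equivalent reductions, written against plain torch.distributed collectives under the assumption that the helpers reduce a scalar across the ranks of the mesh they are given; the underscore-prefixed names are hypothetical, not the repo's.

# Illustrative sketch only (not the repo's dist_mean/dist_max): scalar
# reductions over the process group of a 1-D mesh such as dp_mesh.
import torch
import torch.distributed as dist

def _dist_mean_sketch(value: float, mesh) -> float:
    # A 1-D DeviceMesh maps to a single process group; the 2-D world_mesh
    # does not, which is why passing it here crashed before this commit.
    pg = mesh.get_group()
    t = torch.tensor(value, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=pg)
    return (t / mesh.size()).item()

def _dist_max_sketch(value: float, mesh) -> float:
    pg = mesh.get_group()
    t = torch.tensor(value, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=pg)
    return t.item()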
