
Commit c99c4b4

[chat] add distributed impl
1 parent 9379cbd commit c99c4b4

9 files changed: +798 -0 lines changed

Lines changed: 6 additions & 0 deletions

# Requirements

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
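
As a quick sanity check (illustrative only, not part of this commit), the install can be verified by importing CuPy's NCCL bindings on a CUDA 12.x machine:

```python
# Illustrative check (not from this commit): confirm that CuPy and its NCCL
# bindings are importable after running the commands above.
import cupy
from cupy.cuda import nccl  # raises ImportError if the NCCL library is missing

print("CuPy version:", cupy.__version__)
print("NCCL version:", nccl.get_version())
```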

applications/ColossalChat/coati/distributed/__init__.py

Whitespace-only changes.
applications/ColossalChat/coati/distributed/comm.py

Lines changed: 57 additions & 0 deletions
```python
from typing import Any, Dict

import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version


def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
    # Serialize an arbitrary Python object on the source rank and broadcast it
    # over a ray.util.collective group: size first, then the byte payload.
    rank = cc.get_rank(group_name)
    if rank == src:
        if Version(torch.__version__) >= Version("2.3.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
        elif Version(torch.__version__) >= Version("1.13.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
        else:
            obj_tensor, size_tensor = c10d._object_to_tensor(obj)
        obj_tensor = obj_tensor.to(device)
        size_tensor = size_tensor.to(device)
    else:
        size_tensor = torch.empty(1, dtype=torch.int64, device=device)
    cc.broadcast(size_tensor, src, group_name)
    if rank != src:
        obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
    cc.broadcast(obj_tensor, src, group_name)
    if rank != src:
        if Version(torch.__version__) >= Version("2.3.0"):
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
        else:
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
    return obj


def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
    # Broadcast a dict of tensors: first the (key, shape, dtype) metadata as an
    # object, then each tensor one by one. Non-source ranks may pass None.
    rank = cc.get_rank(group_name)
    if rank == src:
        metadata = []
        for k, v in tensor_dict.items():
            metadata.append((k, v.shape, v.dtype))
    else:
        metadata = None
    metadata = ray_broadcast_object(metadata, src, device, group_name)
    if rank != src:
        out_dict = {}
    for k, shape, dtype in metadata:
        if rank == src:
            tensor = tensor_dict[k]
        else:
            tensor = torch.empty(shape, dtype=dtype, device=device)
        cc.broadcast(tensor, src, group_name)
        if rank != src:
            out_dict[k] = tensor
    if rank == src:
        out_dict = tensor_dict
    return out_dict
```
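
As a usage illustration (not part of this diff), the broadcast helpers can be exercised between two Ray actors that join the same collective group. The `Worker` actor, the group bootstrap, and the `coati.distributed.comm` import path below are assumptions based on the package layout shown above:

```python
# Minimal sketch: broadcast a tensor dict from actor rank 0 to rank 1 over a
# ray.util.collective NCCL group. Assumes 2 GPUs and that the helpers above
# live in coati.distributed.comm (inferred from the consumer's relative import).
import ray
import ray.util.collective as cc
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

    def setup(self):
        # every participant must join the same named collective group
        cc.init_collective_group(self.world_size, self.rank, backend="nccl", group_name="default")

    def run(self):
        device = torch.device("cuda")
        # only the source rank provides tensors; receivers pass None and get the dict back
        tensors = {"weights": torch.randn(4, 4, device=device)} if self.rank == 0 else None
        out = ray_broadcast_tensor_dict(tensors, src=0, device=device, group_name="default")
        return {k: tuple(v.shape) for k, v in out.items()}


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote(rank, 2) for rank in range(2)]
    ray.get([w.setup.remote() for w in workers])
    print(ray.get([w.run.remote() for w in workers]))
```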

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 190 additions & 0 deletions
```python
from contextlib import nullcontext
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch


class BaseConsumer:
    def __init__(
        self,
        num_producers: int,
        num_episodes: int,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        num_update_per_episode: int,
        num_recv_per_update: int,
        batch_size: int,
        model_config: Dict[str, Any],
        plugin_config: Dict[str, Any],
        microbatch_size: int = 1,
    ):
        self.num_producers = num_producers
        self.num_episodes = num_episodes
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.num_update_per_episode = num_update_per_episode
        self.num_recv_per_update = num_recv_per_update
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
        self.num_microbatches = batch_size // microbatch_size

        self.model_config = model_config
        self.plugin_config = plugin_config
        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

        self.device = get_current_device()

    def setup(self) -> None:
        # One data group per producer; the producer holds rank 0 there, so this consumer joins at rank + 1.
        for i in range(self.num_producers):
            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
        # Consumer rank 0 additionally joins the model-sync group shared with all producers.
        if self.rank == 0:
            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

        plugin_config = dict(
            tp_size=1,
            pp_size=1,
            precision="bf16",
            zero_stage=1,
        )
        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
            plugin_config["microbatch_size"] = self.microbatch_size
        plugin_config.update(self.plugin_config)
        self.plugin = HybridParallelPlugin(**plugin_config)
        self.booster = Booster(plugin=self.plugin)
        self.dp_rank = dist.get_rank(self.plugin.dp_group)
        self.dp_size = dist.get_world_size(self.plugin.dp_group)

        self.buffer = []

        self.recv_cnt = 0

    def state_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        raise NotImplementedError

    def loop(self) -> None:
        print(
            f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
        )
        for episode in range(self.num_episodes):
            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
                for step in pbar:
                    i = 0
                    for _ in range(self.num_recv_per_update):
                        # receive data from producers
                        for r in range(self.num_producers):
                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                            self.buffer.extend(
                                unbind_batch(
                                    ray_broadcast_tensor_dict(
                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
                                    )
                                )
                            )
                        # consume this DP rank's slice of the buffer, one microbatch at a time
                        while len(self.buffer) >= self.dp_size * self.microbatch_size:
                            batches = self.buffer[
                                self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
                            ]
                            self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                            batch = bind_batch(batches)
                            batch = post_recv(batch)
                            loss = self.step(i, **batch)
                            if loss is not None:
                                pbar.set_postfix({"loss": loss})
                            i += 1
                    assert len(self.buffer) == 0
                    # sync the updated weights back to the producers, except after the very last step
                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                        state_dict = self.state_dict()
                        if self.rank == 0:
                            ray_broadcast_tensor_dict(
                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                            )


@ray.remote
class SimpleConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        path = model_config.pop("path")
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.model.train()
        self.model.gradient_checkpointing_enable()
        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
        self.accum_loss = torch.zeros(1, device=self.device)

    def setup(self):
        super().setup()
        self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        need_update = (step_idx + 1) % self.num_microbatches == 0

        # skip gradient sync on all but the last microbatch of the batch
        ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
        with ctx:
            out = self.model(**kwargs)
            loss = out.loss / self.num_microbatches
            self.accum_loss.add_(loss.data)
            self.booster.backward(loss, self.optimizer)
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            self.accum_loss.zero_()
            return loss_scalar

    def state_dict(self):
        self.model._force_wait_all_gather()
        model = self.model.unwrap()
        state_dict = model.state_dict()
        return state_dict
```
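
For orientation, here is a hypothetical driver-side launch of the trainer actor (not part of this diff). The `coati.distributed.consumer` module name, the model path, and all numeric settings are placeholders; the producer actors that feed the `sync_data_*` groups come from other files in this commit and must be running, otherwise the collective-group setup and the first broadcast will block:

```python
# Hypothetical driver sketch: create one SimpleConsumer actor on one GPU and run it.
import ray

from coati.distributed.consumer import SimpleConsumer  # module name assumed

ray.init()
consumer = SimpleConsumer.options(num_gpus=1).remote(
    num_producers=1,
    num_episodes=1,
    rank=0,
    world_size=1,
    master_addr="127.0.0.1",
    master_port=29506,
    num_update_per_episode=10,
    num_recv_per_update=1,
    batch_size=4,                                          # must be divisible by microbatch_size
    model_config={"path": "Qwen/Qwen2.5-0.5B-Instruct"},   # placeholder causal-LM checkpoint
    plugin_config={},                                      # defaults: tp=1, pp=1, zero-1, bf16
    microbatch_size=2,
)
ray.get(consumer.setup.remote())  # joins the collective groups and builds the Booster
ray.get(consumer.loop.remote())   # requires the producer actors to be up and sending data
```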
