Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions applications/ColossalChat/coati/distributed/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Requirements

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
Empty file.
57 changes: 57 additions & 0 deletions applications/ColossalChat/coati/distributed/comm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Any, Dict

import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version


def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
rank = cc.get_rank(group_name)
if rank == src:
if Version(torch.__version__) >= Version("2.3.0"):
obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
elif Version(torch.__version__) >= Version("1.13.0"):
obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
else:
obj_tensor, size_tensor = c10d._object_to_tensor(obj)
obj_tensor = obj_tensor.to(device)
size_tensor = size_tensor.to(device)
else:
size_tensor = torch.empty(1, dtype=torch.int64, device=device)
cc.broadcast(size_tensor, src, group_name)
if rank != src:
obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
cc.broadcast(obj_tensor, src, group_name)
if rank != src:
if Version(torch.__version__) >= Version("2.3.0"):
obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
else:
obj = c10d._tensor_to_object(obj, size_tensor.item())
return obj


def ray_broadcast_tensor_dict(
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
rank = cc.get_rank(group_name)
if rank == src:
metadata = []
for k, v in tensor_dict.items():
metadata.append((k, v.shape, v.dtype))
else:
metadata = None
metadata = ray_broadcast_object(metadata, src, device, group_name)
if rank != src:
out_dict = {}
for k, shape, dtype in metadata:
if rank == src:
tensor = tensor_dict[k]
else:
tensor = torch.empty(shape, dtype=dtype, device=device)
cc.broadcast(tensor, src, group_name)
if rank != src:
out_dict[k] = tensor
if rank == src:
out_dict = tensor_dict
return out_dict
190 changes: 190 additions & 0 deletions applications/ColossalChat/coati/distributed/consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch


class BaseConsumer:
def __init__(
self,
num_producers: int,
num_episodes: int,
rank: int,
world_size: int,
master_addr: str,
master_port: int,
num_update_per_episode: int,
num_recv_per_update: int,
batch_size: int,
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.num_update_per_episode = num_update_per_episode
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size

self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

self.device = get_current_device()

def setup(self) -> None:
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

plugin_config = dict(
tp_size=1,
pp_size=1,
precision="bf16",
zero_stage=1,
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)

self.buffer = []

self.recv_cnt = 0

def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError

def step(self, step_idx: int, **kwargs) -> Optional[float]:
raise NotImplementedError

def loop(self) -> None:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
for episode in range(self.num_episodes):
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
for step in pbar:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers

for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
unbind_batch(
ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
)
)
while len(self.buffer) >= self.dp_size * self.microbatch_size:
batches = self.buffer[
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, **batch)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
state_dict = self.state_dict()
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)


@ray.remote
class SimpleConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.model.train()
self.model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
self.accum_loss = torch.zeros(1, device=self.device)

def setup(self):
super().setup()
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

def step(self, step_idx: int, **kwargs) -> Optional[float]:
need_update = (step_idx + 1) % self.num_microbatches == 0

ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
with ctx:
out = self.model(**kwargs)
loss = out.loss / self.num_microbatches
self.accum_loss.add_(loss.data)
self.booster.backward(loss, self.optimizer)
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
self.accum_loss.zero_()
return loss_scalar

def state_dict(self):
self.model._force_wait_all_gather()
model = self.model.unwrap()
state_dict = model.state_dict()
return state_dict
Loading