|
| 1 | +from contextlib import nullcontext |
| 2 | +from typing import Any, Dict, Optional |
| 3 | + |
| 4 | +import ray |
| 5 | +import ray.util.collective as cc |
| 6 | +import torch |
| 7 | +import torch.distributed as dist |
| 8 | +from tqdm import tqdm |
| 9 | +from transformers import AutoModelForCausalLM |
| 10 | + |
| 11 | +from colossalai.booster import Booster |
| 12 | +from colossalai.booster.plugin import HybridParallelPlugin |
| 13 | +from colossalai.initialize import launch |
| 14 | +from colossalai.nn.optimizer import HybridAdam |
| 15 | +from colossalai.utils import get_current_device |
| 16 | + |
| 17 | +from .comm import ray_broadcast_tensor_dict |
| 18 | +from .utils import bind_batch, post_recv, unbind_batch |
| 19 | + |
| 20 | + |
| 21 | +class BaseConsumer: |
| 22 | + def __init__( |
| 23 | + self, |
| 24 | + num_producers: int, |
| 25 | + num_episodes: int, |
| 26 | + rank: int, |
| 27 | + world_size: int, |
| 28 | + master_addr: str, |
| 29 | + master_port: int, |
| 30 | + num_update_per_episode: int, |
| 31 | + num_recv_per_update: int, |
| 32 | + batch_size: int, |
| 33 | + model_config: Dict[str, Any], |
| 34 | + plugin_config: Dict[str, Any], |
| 35 | + microbatch_size: int = 1, |
| 36 | + ): |
| 37 | + self.num_producers = num_producers |
| 38 | + self.num_episodes = num_episodes |
| 39 | + self.rank = rank |
| 40 | + self.world_size = world_size |
| 41 | + self.master_addr = master_addr |
| 42 | + self.master_port = master_port |
| 43 | + self.num_update_per_episode = num_update_per_episode |
| 44 | + self.num_recv_per_update = num_recv_per_update |
| 45 | + self.batch_size = batch_size |
| 46 | + self.microbatch_size = microbatch_size |
| 47 | + assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size" |
| 48 | + self.num_microbatches = batch_size // microbatch_size |
| 49 | + |
| 50 | + self.model_config = model_config |
| 51 | + self.plugin_config = plugin_config |
| 52 | + assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now" |
| 53 | + |
| 54 | + self.device = get_current_device() |
| 55 | + |
| 56 | + def setup(self) -> None: |
| 57 | + for i in range(self.num_producers): |
| 58 | + cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}") |
| 59 | + if self.rank == 0: |
| 60 | + cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") |
| 61 | + launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) |
| 62 | + |
| 63 | + plugin_config = dict( |
| 64 | + tp_size=1, |
| 65 | + pp_size=1, |
| 66 | + precision="bf16", |
| 67 | + zero_stage=1, |
| 68 | + ) |
| 69 | + if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: |
| 70 | + plugin_config["microbatch_size"] = self.microbatch_size |
| 71 | + plugin_config.update(self.plugin_config) |
| 72 | + self.plugin = HybridParallelPlugin(**plugin_config) |
| 73 | + self.booster = Booster(plugin=self.plugin) |
| 74 | + self.dp_rank = dist.get_rank(self.plugin.dp_group) |
| 75 | + self.dp_size = dist.get_world_size(self.plugin.dp_group) |
| 76 | + |
| 77 | + self.buffer = [] |
| 78 | + |
| 79 | + self.recv_cnt = 0 |
| 80 | + |
| 81 | + def state_dict(self) -> Dict[str, torch.Tensor]: |
| 82 | + raise NotImplementedError |
| 83 | + |
| 84 | + def step(self, step_idx: int, **kwargs) -> Optional[float]: |
| 85 | + raise NotImplementedError |
| 86 | + |
| 87 | + def loop(self) -> None: |
| 88 | + print( |
| 89 | + f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}" |
| 90 | + ) |
| 91 | + for episode in range(self.num_episodes): |
| 92 | + with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar: |
| 93 | + for step in pbar: |
| 94 | + i = 0 |
| 95 | + for _ in range(self.num_recv_per_update): |
| 96 | + # receive data from producers |
| 97 | + |
| 98 | + for r in range(self.num_producers): |
| 99 | + print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") |
| 100 | + self.buffer.extend( |
| 101 | + unbind_batch( |
| 102 | + ray_broadcast_tensor_dict( |
| 103 | + None, src=0, device=self.device, group_name=f"sync_data_{r}" |
| 104 | + ) |
| 105 | + ) |
| 106 | + ) |
| 107 | + while len(self.buffer) >= self.dp_size * self.microbatch_size: |
| 108 | + batches = self.buffer[ |
| 109 | + self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size |
| 110 | + ] |
| 111 | + self.buffer = self.buffer[self.dp_size * self.microbatch_size :] |
| 112 | + batch = bind_batch(batches) |
| 113 | + batch = post_recv(batch) |
| 114 | + loss = self.step(i, **batch) |
| 115 | + if loss is not None: |
| 116 | + pbar.set_postfix({"loss": loss}) |
| 117 | + i += 1 |
| 118 | + assert len(self.buffer) == 0 |
| 119 | + if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: |
| 120 | + print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") |
| 121 | + state_dict = self.state_dict() |
| 122 | + if self.rank == 0: |
| 123 | + ray_broadcast_tensor_dict( |
| 124 | + state_dict, src=self.num_producers, device=self.device, group_name="sync_model" |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +@ray.remote |
| 129 | +class SimpleConsumer(BaseConsumer): |
| 130 | + def __init__( |
| 131 | + self, |
| 132 | + num_producers, |
| 133 | + num_episodes, |
| 134 | + rank, |
| 135 | + world_size, |
| 136 | + master_addr, |
| 137 | + master_port, |
| 138 | + num_update_per_episode, |
| 139 | + num_recv_per_update, |
| 140 | + batch_size, |
| 141 | + model_config, |
| 142 | + plugin_config, |
| 143 | + microbatch_size=1, |
| 144 | + ): |
| 145 | + super().__init__( |
| 146 | + num_producers, |
| 147 | + num_episodes, |
| 148 | + rank, |
| 149 | + world_size, |
| 150 | + master_addr, |
| 151 | + master_port, |
| 152 | + num_update_per_episode, |
| 153 | + num_recv_per_update, |
| 154 | + batch_size, |
| 155 | + model_config, |
| 156 | + plugin_config, |
| 157 | + microbatch_size, |
| 158 | + ) |
| 159 | + path = model_config.pop("path") |
| 160 | + self.model = AutoModelForCausalLM.from_pretrained(path, **model_config) |
| 161 | + self.model.train() |
| 162 | + self.model.gradient_checkpointing_enable() |
| 163 | + self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3) |
| 164 | + self.accum_loss = torch.zeros(1, device=self.device) |
| 165 | + |
| 166 | + def setup(self): |
| 167 | + super().setup() |
| 168 | + self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer) |
| 169 | + |
| 170 | + def step(self, step_idx: int, **kwargs) -> Optional[float]: |
| 171 | + need_update = (step_idx + 1) % self.num_microbatches == 0 |
| 172 | + |
| 173 | + ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer) |
| 174 | + with ctx: |
| 175 | + out = self.model(**kwargs) |
| 176 | + loss = out.loss / self.num_microbatches |
| 177 | + self.accum_loss.add_(loss.data) |
| 178 | + self.booster.backward(loss, self.optimizer) |
| 179 | + if need_update: |
| 180 | + self.optimizer.step() |
| 181 | + self.optimizer.zero_grad() |
| 182 | + loss_scalar = self.accum_loss.item() |
| 183 | + self.accum_loss.zero_() |
| 184 | + return loss_scalar |
| 185 | + |
| 186 | + def state_dict(self): |
| 187 | + self.model._force_wait_all_gather() |
| 188 | + model = self.model.unwrap() |
| 189 | + state_dict = model.state_dict() |
| 190 | + return state_dict |
0 commit comments