From 43c9b5fb4408c5ce7f9dd19d1eed18737489b4e0 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 21 Feb 2025 15:24:23 +0800
Subject: [PATCH] [chat] add distributed impl (#6210)

---
 .../ColossalChat/coati/distributed/README.md  |   6 +
 .../coati/distributed/__init__.py             |   0
 .../ColossalChat/coati/distributed/comm.py    |  57 ++++++
 .../coati/distributed/consumer.py             | 190 ++++++++++++++++++
 .../coati/distributed/inference_backend.py    | 164 +++++++++++++++
 .../ColossalChat/coati/distributed/launch.py  |  87 ++++++++
 .../coati/distributed/producer.py             | 160 +++++++++++++++
 .../ColossalChat/coati/distributed/utils.py   |  40 ++++
 applications/ColossalChat/rl_example.py       |  94 +++++++++
 9 files changed, 798 insertions(+)
 create mode 100644 applications/ColossalChat/coati/distributed/README.md
 create mode 100644 applications/ColossalChat/coati/distributed/__init__.py
 create mode 100644 applications/ColossalChat/coati/distributed/comm.py
 create mode 100644 applications/ColossalChat/coati/distributed/consumer.py
 create mode 100644 applications/ColossalChat/coati/distributed/inference_backend.py
 create mode 100644 applications/ColossalChat/coati/distributed/launch.py
 create mode 100644 applications/ColossalChat/coati/distributed/producer.py
 create mode 100644 applications/ColossalChat/coati/distributed/utils.py
 create mode 100644 applications/ColossalChat/rl_example.py

diff --git a/applications/ColossalChat/coati/distributed/README.md b/applications/ColossalChat/coati/distributed/README.md
new file mode 100644
index 000000000..b7bac2b2d
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/README.md
@@ -0,0 +1,6 @@
+# Requirements
+
+```bash
+pip install cupy-cuda12x
+python -m cupyx.tools.install_library --cuda 12.x --library nccl
+```
diff --git a/applications/ColossalChat/coati/distributed/__init__.py b/applications/ColossalChat/coati/distributed/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/applications/ColossalChat/coati/distributed/comm.py b/applications/ColossalChat/coati/distributed/comm.py
new file mode 100644
index 000000000..3824303f5
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/comm.py
@@ -0,0 +1,57 @@
+from typing import Any, Dict
+
+import ray.util.collective as cc
+import torch
+import torch.distributed.distributed_c10d as c10d
+from packaging.version import Version
+
+
+def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
+    rank = cc.get_rank(group_name)
+    if rank == src:
+        if Version(torch.__version__) >= Version("2.3.0"):
+            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
+        elif Version(torch.__version__) >= Version("1.13.0"):
+            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
+        else:
+            obj_tensor, size_tensor = c10d._object_to_tensor(obj)
+        obj_tensor = obj_tensor.to(device)
+        size_tensor = size_tensor.to(device)
+    else:
+        size_tensor = torch.empty(1, dtype=torch.int64, device=device)
+    cc.broadcast(size_tensor, src, group_name)
+    if rank != src:
+        obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
+    cc.broadcast(obj_tensor, src, group_name)
+    if rank != src:
+        if Version(torch.__version__) >= Version("2.3.0"):
+            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
+        else:
+            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
+    return obj
+
+
+def ray_broadcast_tensor_dict(
+    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+) -> Dict[str, torch.Tensor]:
+    rank = cc.get_rank(group_name)
+    if rank == src:
+        metadata = []
+        for k, v in tensor_dict.items():
+            metadata.append((k, v.shape, v.dtype))
+    else:
+        metadata = None
+    metadata = ray_broadcast_object(metadata, src, device, group_name)
+    if rank != src:
+        out_dict = {}
+    for k, shape, dtype in metadata:
+        if rank == src:
+            tensor = tensor_dict[k]
+        else:
+            tensor = torch.empty(shape, dtype=dtype, device=device)
+        cc.broadcast(tensor, src, group_name)
+        if rank != src:
+            out_dict[k] = tensor
+    if rank == src:
+        out_dict = tensor_dict
+    return out_dict
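For reference, a minimal sketch of how these two helpers are meant to be driven (not part of the patch): two Ray actors with one GPU each join the same named collective group, then the non-source rank passes `None` and receives the broadcast dict. The `Worker` actor and tensor shapes are illustrative; this assumes the cupy NCCL install from the README and at least two GPUs.

```python
import ray
import ray.util.collective as cc
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, rank: int):
        self.rank = rank

    def run(self):
        # Every participant joins the group with a unique rank before broadcasting.
        cc.init_collective_group(world_size=2, rank=self.rank, group_name="default")
        if self.rank == 0:
            data = {"input_ids": torch.ones(2, 4, dtype=torch.long, device="cuda")}
        else:
            data = None  # non-src ranks receive the dict instead of sending one
        out = ray_broadcast_tensor_dict(data, src=0, device="cuda", group_name="default")
        return {k: tuple(v.shape) for k, v in out.items()}


workers = [Worker.remote(i) for i in range(2)]
print(ray.get([w.run.remote() for w in workers]))  # both ranks report the same shapes
```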
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
new file mode 100644
index 000000000..61417f7e6
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -0,0 +1,190 @@
+from contextlib import nullcontext
+from typing import Any, Dict, Optional
+
+import ray
+import ray.util.collective as cc
+import torch
+import torch.distributed as dist
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.initialize import launch
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+from .comm import ray_broadcast_tensor_dict
+from .utils import bind_batch, post_recv, unbind_batch
+
+
+class BaseConsumer:
+    def __init__(
+        self,
+        num_producers: int,
+        num_episodes: int,
+        rank: int,
+        world_size: int,
+        master_addr: str,
+        master_port: int,
+        num_update_per_episode: int,
+        num_recv_per_update: int,
+        batch_size: int,
+        model_config: Dict[str, Any],
+        plugin_config: Dict[str, Any],
+        microbatch_size: int = 1,
+    ):
+        self.num_producers = num_producers
+        self.num_episodes = num_episodes
+        self.rank = rank
+        self.world_size = world_size
+        self.master_addr = master_addr
+        self.master_port = master_port
+        self.num_update_per_episode = num_update_per_episode
+        self.num_recv_per_update = num_recv_per_update
+        self.batch_size = batch_size
+        self.microbatch_size = microbatch_size
+        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
+        self.num_microbatches = batch_size // microbatch_size
+
+        self.model_config = model_config
+        self.plugin_config = plugin_config
+        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
+
+        self.device = get_current_device()
+
+    def setup(self) -> None:
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+        if self.rank == 0:
+            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
+        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
+
+        plugin_config = dict(
+            tp_size=1,
+            pp_size=1,
+            precision="bf16",
+            zero_stage=1,
+        )
+        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
+            plugin_config["microbatch_size"] = self.microbatch_size
+        plugin_config.update(self.plugin_config)
+        self.plugin = HybridParallelPlugin(**plugin_config)
+        self.booster = Booster(plugin=self.plugin)
+        self.dp_rank = dist.get_rank(self.plugin.dp_group)
+        self.dp_size = dist.get_world_size(self.plugin.dp_group)
+
+        self.buffer = []
+
+        self.recv_cnt = 0
+
+    def state_dict(self) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        raise NotImplementedError
+
+    def loop(self) -> None:
+        print(
+            f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
+        )
+        for episode in range(self.num_episodes):
+            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+                for step in pbar:
+                    i = 0
+                    for _ in range(self.num_recv_per_update):
+                        # receive data from producers
+
+                        for r in range(self.num_producers):
+                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
+                            self.buffer.extend(
+                                unbind_batch(
+                                    ray_broadcast_tensor_dict(
+                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                                    )
+                                )
+                            )
+                        while len(self.buffer) >= self.dp_size * self.microbatch_size:
+                            batches = self.buffer[
+                                self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
+                            ]
+                            self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
+                            batch = bind_batch(batches)
+                            batch = post_recv(batch)
+                            loss = self.step(i, **batch)
+                            if loss is not None:
+                                pbar.set_postfix({"loss": loss})
+                            i += 1
+                    assert len(self.buffer) == 0
+                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        state_dict = self.state_dict()
+                        if self.rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                            )
+
+
+@ray.remote
+class SimpleConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.model.train()
+        self.model.gradient_checkpointing_enable()
+        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
+        self.accum_loss = torch.zeros(1, device=self.device)
+
+    def setup(self):
+        super().setup()
+        self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        need_update = (step_idx + 1) % self.num_microbatches == 0
+
+        ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
+        with ctx:
+            out = self.model(**kwargs)
+            loss = out.loss / self.num_microbatches
+            self.accum_loss.add_(loss.data)
+            self.booster.backward(loss, self.optimizer)
+        if need_update:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            loss_scalar = self.accum_loss.item()
+            self.accum_loss.zero_()
+            return loss_scalar
+
+    def state_dict(self):
+        self.model._force_wait_all_gather()
+        model = self.model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
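`SimpleConsumer` is a plain SFT-style trainer; to plug in a different algorithm, only `step` and `state_dict` must be overridden. A skeleton, with `MyConsumer` and its body as placeholders (the `...` stubs are deliberately left to the implementer):

```python
from typing import Dict, Optional

import ray
import torch

from coati.distributed.consumer import BaseConsumer


@ray.remote
class MyConsumer(BaseConsumer):
    """Sketch of a custom trainer; `step` and `state_dict` are the required overrides."""

    def setup(self) -> None:
        super().setup()  # joins the sync groups and builds self.plugin / self.booster
        # Build and boost your model/optimizer here, as SimpleConsumer does.

    def step(self, step_idx: int, **batch) -> Optional[float]:
        # `batch` is one training microbatch: input_ids, attention_mask, labels.
        # Return a float to show in the progress bar, or None to skip the postfix.
        ...

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # Rank 0 broadcasts this dict over "sync_model" after each update, so it
        # must contain tensors the inference backend can load directly.
        ...
```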
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
new file mode 100644
index 000000000..d40808ab4
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -0,0 +1,164 @@
+from typing import Any, Dict
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
+
+from colossalai.utils import get_current_device
+
+try:
+    import sglang as sgl
+except ImportError:
+    sgl = None
+
+try:
+    from vllm import LLM, SamplingParams
+except ImportError:
+    LLM = None
+
+
+class BaseInferenceBackend:
+    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+        pass
+
+    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        pass
+
+    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+        pass
+
+
+class TransformersInferenceBackend(BaseInferenceBackend):
+    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+        path = model_config.pop("path")
+        default_config = dict(
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        default_config.update(model_config)
+        self.model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(path, **default_config)
+        self.generate_config = generate_config
+
+    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        input_ids = input_ids.to(get_current_device())
+        attention_mask = attention_mask.to(get_current_device())
+        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+        input_len = input_ids.shape[-1]
+        labels = out.clone()
+        labels[..., :input_len] = -100
+        attention_mask = F.pad(attention_mask, (0, out.shape[-1] - input_len), value=1)
+        attention_mask = attention_mask.expand_as(labels)
+        data = {
+            "input_ids": out,
+            "attention_mask": attention_mask,
+            "labels": labels,
+        }
+        return data
+
+    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+        self.model.load_state_dict(state_dict)
+
+
+class SGLangInferenceBackend(BaseInferenceBackend):
+    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+        if sgl is None:
+            raise ImportError("sglang is not installed")
+        path = model_config.pop("path")
+        default_config = dict(
+            trust_remote_code=True,
+            skip_tokenizer_init=True,
+        )
+        default_config.update(model_config)
+        self.llm = sgl.Engine(model_path=path, **default_config)
+        self.generate_config = generate_config
+        self.tokenizer = tokenizer
+        self.config = AutoConfig.from_pretrained(path)
+
+    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
+        out_tokens = []
+        out_len = []
+        for out in outputs:
+            out_tokens.append(out["token_ids"])
+            out_len.append(out["meta_info"]["completion_tokens"])
+        max_len = max(out_len)
+        input_len = input_ids.shape[-1]
+        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
+        for i in range(len(out_tokens)):
+            out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
+            attention_mask[i, input_len + out_len[i] :] = 0
+        out = torch.tensor(out_tokens)
+        out = torch.cat((input_ids, out), dim=1)
+        labels = out.clone()
+        labels[..., :input_len] = -100
+        for i in range(len(out_len)):
+            labels[i, input_len + out_len[i] :] = -100
+        data = {
+            "input_ids": out,
+            "attention_mask": attention_mask,
+            "labels": labels,
+        }
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
+        return data
+
+    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+        if self.config.tie_word_embeddings:
+            del state_dict["lm_head.weight"]
+        named_tensors = [(k, v) for k, v in state_dict.items()]
+        self.llm.update_weights_from_tensor(named_tensors)
+
+
+class VLLMInferenceBackend(BaseInferenceBackend):
+    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+        if LLM is None:
+            raise ImportError("vllm is not installed")
+        path = model_config.pop("path")
+        default_config = dict(
+            trust_remote_code=True,
+            # skip_tokenizer_init=True,
+        )
+        default_config.update(model_config)
+        self.llm = LLM(path, **default_config)
+        self.generate_config = SamplingParams(**generate_config, stop_token_ids=[tokenizer.eos_token_id])
+        self.tokenizer = tokenizer
+        self.config = AutoConfig.from_pretrained(path)
+
+    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        outputs = self.llm.generate(
+            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+        )
+        out_tokens = []
+        out_len = []
+        for out in outputs:
+            out_tokens.append(list(out.outputs[0].token_ids))
+            out_len.append(len(out.outputs[0].token_ids))
+        max_len = max(out_len)
+        input_len = input_ids.shape[-1]
+        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
+        for i in range(len(out_tokens)):
+            out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
+            attention_mask[i, input_len + out_len[i] :] = 0
+        out = torch.tensor(out_tokens)
+        out = torch.cat((input_ids, out), dim=1)
+        labels = out.clone()
+        labels[..., :input_len] = -100
+        for i in range(len(out_len)):
+            labels[i, input_len + out_len[i] :] = -100
+        data = {
+            "input_ids": out,
+            "attention_mask": attention_mask,
+            "labels": labels,
+        }
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
+        return data
+
+    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+        self.llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())
+
+
+BACKEND_MAP = {
+    "transformers": TransformersInferenceBackend,
+    "sglang": SGLangInferenceBackend,
+    "vllm": VLLMInferenceBackend,
+}
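All three backends share one contract: `generate(input_ids, attention_mask)` returns padded `input_ids`/`attention_mask`/`labels` with prompt positions masked to -100. A sketch of driving one backend directly, outside Ray (assumes a CUDA machine where the model fits; prompts and generation settings are arbitrary):

```python
import torch
from transformers import AutoTokenizer

from coati.distributed.inference_backend import BACKEND_MAP

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
tokenizer.padding_side = "left"  # the producer sets this too

backend = BACKEND_MAP["transformers"](
    model_config={"path": "Qwen/Qwen2.5-7B", "torch_dtype": torch.bfloat16},
    generate_config={"max_new_tokens": 32, "do_sample": True},
    tokenizer=tokenizer,
)

enc = tokenizer(["1 + 1 =", "The capital of France is"], return_tensors="pt", padding=True)
data = backend.generate(enc["input_ids"], enc["attention_mask"])
# Only completion tokens carry labels; prompt positions are -100.
print(data["input_ids"].shape, data["labels"].shape)
```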
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
new file mode 100644
index 000000000..438c46300
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -0,0 +1,87 @@
+from typing import Any, Dict, Optional
+
+import ray
+
+from .consumer import SimpleConsumer
+from .producer import SimpleProducer
+
+
+def get_jsonl_size_fast(path: str) -> int:
+    with open(path) as f:
+        lines = f.readlines()
+        lines = [line for line in lines if line.strip()]
+        return len(lines) - 1
+
+
+def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
+    tp_size = plugin_config.get("tp_size", 1)
+    pp_size = plugin_config.get("pp_size", 1)
+    ep_size = plugin_config.get("ep_size", 1)
+    sp_size = plugin_config.get("sp_size", 1)
+    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+
+
+def launch_distributed(
+    num_producers: int,
+    num_proc_per_producer: int,
+    num_consumer_procs: int,
+    num_episodes: int,
+    inference_batch_size: int,
+    inference_microbatch_size: int,
+    train_batch_size: int,
+    train_microbatch_size: int,
+    dataset_config: Dict[str, Any],
+    dataloaders_config: Dict[str, Any],
+    inference_model_config: Dict[str, Any],
+    generate_config: Dict[str, Any],
+    train_model_config: Dict[str, Any],
+    plugin_config: Dict[str, Any],
+    tokenizer_config: Optional[Dict[str, Any]] = None,
+    inference_backend: str = "transformers",
+    master_addr: str = "localhost",
+    master_port: int = 29500,
+):
+    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
+    assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
+
+    dataset_path = dataset_config["path"]
+    num_samples = get_jsonl_size_fast(dataset_path)
+    global_inference_batch_size = inference_batch_size * num_producers
+    num_update_per_episode = num_samples // global_inference_batch_size
+    num_recv_per_update = inference_batch_size // inference_microbatch_size
+
+    procs = []
+    for i in range(num_producers):
+        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+            producer_idx=i,
+            num_producers=num_producers,
+            num_consumer_procs=num_consumer_procs,
+            num_episodes=num_episodes,
+            batch_size=inference_batch_size,
+            dataset_config=dataset_config,
+            dataloaders_config=dataloaders_config,
+            model_config=inference_model_config,
+            generate_config=generate_config,
+            tokenizer_config=tokenizer_config,
+            microbatch_size=inference_microbatch_size,
+            backend=inference_backend,
+        )
+        procs.append(producer)
+    for i in range(num_consumer_procs):
+        consumer = SimpleConsumer.options(num_gpus=1).remote(
+            num_producers=num_producers,
+            num_episodes=num_episodes,
+            rank=i,
+            world_size=num_consumer_procs,
+            master_addr=master_addr,
+            master_port=master_port,
+            num_update_per_episode=num_update_per_episode,
+            num_recv_per_update=num_recv_per_update,
+            batch_size=train_batch_size,
+            model_config=train_model_config,
+            plugin_config=plugin_config,
+            microbatch_size=train_microbatch_size,
+        )
+        procs.append(consumer)
+    ray.get([p.setup.remote() for p in procs])
+    ray.get([p.loop.remote() for p in procs])
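To see how the schedule falls out of these quantities, a worked example mirroring launch_distributed with the defaults from rl_example.py below (the sample count of data.jsonl is an assumption):

```python
# Defaults from rl_example.py; plugin_config={} means train_dp_size == num_trainers.
num_producers, num_consumer_procs = 2, 2          # 2 + 2 = 4 GPUs total
inference_batch_size, inference_microbatch_size = 64, 16
train_batch_size = 16

train_dp_size = num_consumer_procs  # no tp/pp/ep/sp splitting
# Each global rollout batch must split evenly into global train batches:
# 64 * 2 = 128 rollouts vs 16 * 2 = 32 samples consumed per update.
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

num_samples = 12800  # assumed size of data.jsonl
num_update_per_episode = num_samples // (inference_batch_size * num_producers)
num_recv_per_update = inference_batch_size // inference_microbatch_size
print(num_update_per_episode, num_recv_per_update)  # 100 updates, 4 recvs per update
```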
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
new file mode 100644
index 000000000..3e4a5277a
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -0,0 +1,160 @@
+from typing import Any, Dict, Optional
+
+import ray
+import ray.util.collective as cc
+import torch
+from coati.dataset.loader import RawConversationDataset
+from torch.utils.data import DataLoader, DistributedSampler
+from transformers import AutoTokenizer
+
+from colossalai.utils import get_current_device
+
+from .comm import ray_broadcast_tensor_dict
+from .inference_backend import BACKEND_MAP
+from .utils import pre_send
+
+
+class BaseProducer:
+    def __init__(
+        self,
+        producer_idx: int,
+        num_producers: int,
+        num_consumer_procs: int,
+        num_episodes: int,
+        batch_size: int,
+        dataset_config: Dict[str, Any],
+        dataloaders_config: Dict[str, Any],
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer_config: Optional[Dict[str, Any]] = None,
+        microbatch_size: int = 1,
+        backend: str = "transformers",
+    ):
+        self.producer_idx = producer_idx
+        self.num_producers = num_producers
+        self.num_consumer_procs = num_consumer_procs
+        self.num_episodes = num_episodes
+        self.batch_size = batch_size
+        self.microbatch_size = microbatch_size
+        assert batch_size % microbatch_size == 0
+        self.num_microbatches = batch_size // microbatch_size
+
+        self.dataset_config = dataset_config
+        self.model_config = model_config
+        self.generate_config = generate_config
+        self.tokenizer_config = tokenizer_config
+
+        # init tokenizer
+        if tokenizer_config is None:
+            tokenizer_path = model_config["path"]
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+        else:
+            tokenizer_path = tokenizer_config.pop("path")
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
+        self.tokenizer.padding_side = "left"
+
+        # init dataloader
+        dataset_path = dataset_config.pop("path")
+        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
+        self.dataloader = DataLoader(
+            self.dataset,
+            batch_size=microbatch_size,
+            sampler=DistributedSampler(
+                self.dataset,
+                num_replicas=num_producers,
+                rank=producer_idx,
+                shuffle=True,
+                drop_last=True,
+                seed=42,
+            ),
+            num_workers=4,
+        )
+        self.device = get_current_device()
+
+        # init backend
+        if backend in BACKEND_MAP:
+            self.backend_cls = BACKEND_MAP[backend]
+        else:
+            raise ValueError(f"Unexpected backend {backend}")
+
+    def setup(self) -> None:
+        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
+        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+
+    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError
+
+    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+        raise NotImplementedError
+
+    def loop(self) -> None:
+        num_update_per_episode = len(self.dataloader) // self.num_microbatches
+        num_valid_microbatches = num_update_per_episode * self.num_microbatches
+
+        print(
+            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
+        )
+        for episode in range(self.num_episodes):
+            self.dataloader.sampler.set_epoch(episode)
+            for i, batch in enumerate(self.dataloader):
+                if i >= num_valid_microbatches:
+                    break
+                outputs = self.rollout(**batch)
+                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+                outputs = pre_send(outputs)
+                ray_broadcast_tensor_dict(
+                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
+                )
+
+                if (i + 1) % self.num_microbatches == 0 and (
+                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
+                ):
+                    # don't sync model for last iteration
+                    print(
+                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                    )
+                    state_dict = ray_broadcast_tensor_dict(
+                        None, self.num_producers, device=self.device, group_name="sync_model"
+                    )
+                    self.load_state_dict(state_dict)
+
+
+@ray.remote
+class SimpleProducer(BaseProducer):
+    def __init__(
+        self,
+        producer_idx,
+        num_producers,
+        num_consumer_procs,
+        num_episodes,
+        batch_size,
+        dataset_config,
+        dataloaders_config,
+        model_config,
+        generate_config,
+        tokenizer_config=None,
+        microbatch_size=1,
+        backend="transformers",
+    ):
+        super().__init__(
+            producer_idx,
+            num_producers,
+            num_consumer_procs,
+            num_episodes,
+            batch_size,
+            dataset_config,
+            dataloaders_config,
+            model_config,
+            generate_config,
+            tokenizer_config,
+            microbatch_size,
+            backend,
+        )
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+
+    @torch.no_grad()
+    def rollout(self, input_ids, attention_mask, **kwargs):
+        return self.model.generate(input_ids, attention_mask, **kwargs)
+
+    def load_state_dict(self, state_dict):
+        self.model.load_state_dict(state_dict)
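The group wiring implied by the two `setup()` methods is easy to misread, so here it is spelled out for 2 producers and 2 consumer processes (ranks are per named group; derived from the code above, values illustrative):

```python
# "sync_data_0" (world_size 3): producer 0 -> rank 0 (src of rollouts),
#                               consumer 0 -> rank 1, consumer 1 -> rank 2
# "sync_data_1" (world_size 3): producer 1 -> rank 0,
#                               consumer 0 -> rank 1, consumer 1 -> rank 2
# "sync_model"  (world_size 3): producer 0 -> rank 0, producer 1 -> rank 1,
#                               consumer rank 0 -> rank 2 (src of weights)
#
# Rollouts flow producer -> consumers over sync_data_i (src=0); updated
# weights flow consumer 0 -> producers over sync_model (src=num_producers).
num_producers, num_consumer_procs = 2, 2
groups = {f"sync_data_{i}": {"world_size": 1 + num_consumer_procs, "src": 0} for i in range(num_producers)}
groups["sync_model"] = {"world_size": num_producers + 1, "src": num_producers}
```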
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
new file mode 100644
index 000000000..2f3267a1f
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -0,0 +1,40 @@
+from typing import Dict, List
+
+import torch
+
+
+def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
+    batches = []
+    for k, v in batch.items():
+        if len(batches) == 0:
+            unbound_tensors = v.unbind(0)
+            batches = [{k: tensor} for tensor in unbound_tensors]
+        else:
+            unbound_tensors = v.unbind(0)
+            assert len(batches) == len(unbound_tensors)
+            for i, tensor in enumerate(unbound_tensors):
+                batches[i][k] = tensor
+    return batches
+
+
+def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+    batch = {}
+    for k in batches[0].keys():
+        batch[k] = torch.stack([b[k] for b in batches], dim=0)
+    return batch
+
+
+def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    # compress attention_mask to save bandwidth
+    if "attention_mask" in batch:
+        attention_mask = batch["attention_mask"]
+        batch["attention_mask"] = attention_mask.to(torch.bool)
+    return batch
+
+
+def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    # decompress attention_mask
+    if "attention_mask" in batch:
+        attention_mask = batch["attention_mask"]
+        batch["attention_mask"] = attention_mask.to(torch.int)
+    return batch
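A quick round-trip through these helpers, matching how the producer and consumer use them (shapes are arbitrary):

```python
import torch

from coati.distributed.utils import bind_batch, post_recv, pre_send, unbind_batch

batch = {
    "input_ids": torch.randint(0, 100, (4, 8)),
    "attention_mask": torch.ones(4, 8, dtype=torch.int),
}

# Producer side: a bool mask is 1 byte per element on the wire.
wire = pre_send(batch)
assert wire["attention_mask"].dtype == torch.bool

# Consumer side: restore the integer mask, split into per-sample dicts,
# then re-stack a training microbatch.
restored = post_recv(wire)
samples = unbind_batch(restored)      # list of 4 dicts of 1-D tensors
microbatch = bind_batch(samples[:2])  # dict of (2, 8) tensors
assert microbatch["input_ids"].shape == (2, 8)
```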
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
new file mode 100644
index 000000000..a6f82b3be
--- /dev/null
+++ b/applications/ColossalChat/rl_example.py
@@ -0,0 +1,94 @@
+import argparse
+
+import ray
+import torch
+from coati.distributed.launch import launch_distributed
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
+    parser.add_argument("-t", "--num-trainers", type=int, default=2)
+    parser.add_argument("-i", "--num-inferencer", type=int, default=2)
+    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
+    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
+    parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
+    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
+    parser.add_argument("-b", "--backend", type=str, default="transformers")
+    args = parser.parse_args()
+
+    ray.init(address="local", namespace="ray-example")
+
+    inference_model_config = dict(path=args.model)
+    train_model_config = dict(path=args.model)
+    generate_config = dict(
+        top_k=50,
+        top_p=0.8,
+    )
+
+    if args.backend == "transformers":
+        inference_model_config.update(
+            dict(
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16,
+            )
+        )
+        train_model_config.update(
+            dict(
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16,
+                use_cache=False,
+            )
+        )
+        generate_config.update(
+            dict(
+                max_length=512,
+                do_sample=True,
+                max_new_tokens=None,
+                early_stopping=False,
+            )
+        )
+    elif args.backend == "vllm":
+        inference_model_config.update(
+            dict(
+                gpu_memory_utilization=0.6,
+            )
+        )
+        generate_config.update(
+            dict(
+                max_tokens=256,
+                ignore_eos=True,
+            )
+        )
+    else:
+        inference_model_config.update(
+            dict(
+                mem_fraction_static=0.6,
+            )
+        )
+        generate_config.update(
+            dict(
+                max_new_tokens=256,
+                ignore_eos=True,
+            )
+        )
+
+    launch_distributed(
+        num_producers=args.num_inferencer,
+        num_proc_per_producer=1,
+        num_consumer_procs=args.num_trainers,
+        num_episodes=1,
+        inference_batch_size=args.inference_batch_size,
+        inference_microbatch_size=args.inference_microbatch_size,
+        train_batch_size=args.train_batch_size,
+        train_microbatch_size=args.train_microbatch_size,
+        dataset_config={"path": args.dataset, "max_length": 256},
+        dataloaders_config={},
+        inference_model_config=inference_model_config,
+        generate_config=generate_config,
+        train_model_config=train_model_config,
+        plugin_config={},
+        inference_backend=args.backend,
+        master_addr="localhost",
+        master_port=29504,
+    )