mirror of https://github.com/hpcaitech/ColossalAI
[chat] add distributed impl (#6210)
parent 9379cbd668
commit 43c9b5fb44
@@ -0,0 +1,6 @@
# Requirements

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
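These packages back the NCCL backend of `ray.util.collective`, which every module in this commit uses for cross-actor communication. As a minimal sanity-check sketch (the `Worker` actor, the `"check"` group name, and the tensor values are illustrative and not part of this PR), a two-GPU broadcast can be exercised like this:

```python
import ray
import ray.util.collective as cc
import torch


@ray.remote(num_gpus=1)
class Worker:
    def setup(self, world_size: int, rank: int):
        # join an NCCL collective group; requires cupy + the nccl library installed as above
        cc.init_collective_group(world_size, rank, backend="nccl", group_name="check")
        self.rank = rank

    def run(self) -> float:
        t = torch.full((1,), float(self.rank + 1), device="cuda")
        cc.broadcast(t, 0, group_name="check")  # every rank ends up with rank 0's tensor
        return t.item()


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(2)]
    ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])
    print(ray.get([w.run.remote() for w in workers]))  # expected: [1.0, 1.0]
```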
@@ -0,0 +1,57 @@
from typing import Any, Dict

import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version


def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
    rank = cc.get_rank(group_name)
    if rank == src:
        if Version(torch.__version__) >= Version("2.3.0"):  # the private c10d helper changed signature across torch releases
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
        elif Version(torch.__version__) >= Version("1.13.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
        else:
            obj_tensor, size_tensor = c10d._object_to_tensor(obj)
        obj_tensor = obj_tensor.to(device)
        size_tensor = size_tensor.to(device)
    else:
        size_tensor = torch.empty(1, dtype=torch.int64, device=device)
    cc.broadcast(size_tensor, src, group_name)
    if rank != src:
        obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
    cc.broadcast(obj_tensor, src, group_name)
    if rank != src:
        if Version(torch.__version__) >= Version("2.3.0"):
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
        else:
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
    return obj


def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
    rank = cc.get_rank(group_name)
    if rank == src:
        metadata = []
        for k, v in tensor_dict.items():
            metadata.append((k, v.shape, v.dtype))
    else:
        metadata = None
    metadata = ray_broadcast_object(metadata, src, device, group_name)  # share (key, shape, dtype) so receivers can allocate buffers
    if rank != src:
        out_dict = {}
    for k, shape, dtype in metadata:
        if rank == src:
            tensor = tensor_dict[k]
        else:
            tensor = torch.empty(shape, dtype=dtype, device=device)
        cc.broadcast(tensor, src, group_name)
        if rank != src:
            out_dict[k] = tensor
    if rank == src:
        out_dict = tensor_dict
    return out_dict
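For orientation, the call is symmetric across a collective group: the source rank passes the actual dict, every other rank passes `None` and receives an identical dict back. This is exactly how the producer and consumer modules later in this commit use it (group names as defined there):

```python
# producer side: rank 0 of the "sync_data_0" group sends its rollout outputs
ray_broadcast_tensor_dict(outputs, src=0, device=device, group_name="sync_data_0")

# consumer side: any other rank of the same group passes None and gets the dict back
batch = ray_broadcast_tensor_dict(None, src=0, device=device, group_name="sync_data_0")
```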
@@ -0,0 +1,190 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch


class BaseConsumer:
    def __init__(
        self,
        num_producers: int,
        num_episodes: int,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        num_update_per_episode: int,
        num_recv_per_update: int,
        batch_size: int,
        model_config: Dict[str, Any],
        plugin_config: Dict[str, Any],
        microbatch_size: int = 1,
    ):
        self.num_producers = num_producers
        self.num_episodes = num_episodes
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.num_update_per_episode = num_update_per_episode
        self.num_recv_per_update = num_recv_per_update
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
        self.num_microbatches = batch_size // microbatch_size

        self.model_config = model_config
        self.plugin_config = plugin_config
        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

        self.device = get_current_device()

    def setup(self) -> None:
        for i in range(self.num_producers):
            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")  # producer i is rank 0; consumers take ranks 1..world_size
        if self.rank == 0:
            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")  # consumer rank 0 joins the weight-sync group as its last rank
        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

        plugin_config = dict(
            tp_size=1,
            pp_size=1,
            precision="bf16",
            zero_stage=1,
        )
        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
            plugin_config["microbatch_size"] = self.microbatch_size
        plugin_config.update(self.plugin_config)
        self.plugin = HybridParallelPlugin(**plugin_config)
        self.booster = Booster(plugin=self.plugin)
        self.dp_rank = dist.get_rank(self.plugin.dp_group)
        self.dp_size = dist.get_world_size(self.plugin.dp_group)

        self.buffer = []

        self.recv_cnt = 0

    def state_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        raise NotImplementedError

    def loop(self) -> None:
        print(
            f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
        )
        for episode in range(self.num_episodes):
            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
                for step in pbar:
                    i = 0
                    for _ in range(self.num_recv_per_update):
                        # receive data from producers

                        for r in range(self.num_producers):
                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                            self.buffer.extend(
                                unbind_batch(
                                    ray_broadcast_tensor_dict(
                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
                                    )
                                )
                            )
                        while len(self.buffer) >= self.dp_size * self.microbatch_size:
                            batches = self.buffer[
                                self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
                            ]
                            self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                            batch = bind_batch(batches)
                            batch = post_recv(batch)
                            loss = self.step(i, **batch)
                            if loss is not None:
                                pbar.set_postfix({"loss": loss})
                            i += 1
                    assert len(self.buffer) == 0
                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:  # don't sync model for last iteration
                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                        state_dict = self.state_dict()
                        if self.rank == 0:
                            ray_broadcast_tensor_dict(
                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                            )


@ray.remote
class SimpleConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        path = model_config.pop("path")
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.model.train()
        self.model.gradient_checkpointing_enable()
        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
        self.accum_loss = torch.zeros(1, device=self.device)

    def setup(self):
        super().setup()
        self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        need_update = (step_idx + 1) % self.num_microbatches == 0

        ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
        with ctx:
            out = self.model(**kwargs)
            loss = out.loss / self.num_microbatches
            self.accum_loss.add_(loss.data)
            self.booster.backward(loss, self.optimizer)
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            self.accum_loss.zero_()
            return loss_scalar

    def state_dict(self):
        self.model._force_wait_all_gather()
        model = self.model.unwrap()
        state_dict = model.state_dict()
        return state_dict
@@ -0,0 +1,164 @@
from typing import Any, Dict

import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer

from colossalai.utils import get_current_device

try:
    import sglang as sgl
except ImportError:
    sgl = None

try:
    from vllm import LLM, SamplingParams
except ImportError:
    LLM = None


class BaseInferenceBackend:
    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        pass

    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        pass

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        pass


class TransformersInferenceBackend(BaseInferenceBackend):
    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        path = model_config.pop("path")
        default_config = dict(
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        default_config.update(model_config)
        self.model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(path, **default_config)
        self.generate_config = generate_config

    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        input_ids = input_ids.to(get_current_device())
        attention_mask = attention_mask.to(get_current_device())
        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
        input_len = input_ids.shape[-1]
        labels = out.clone()
        labels[..., :input_len] = -100
        attention_mask = F.pad(attention_mask, (0, out.shape[-1] - input_len), value=1)
        attention_mask = attention_mask.expand_as(labels)
        data = {
            "input_ids": out,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        return data

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.model.load_state_dict(state_dict)


class SGLangInferenceBackend(BaseInferenceBackend):
    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        if sgl is None:
            raise ImportError("sglang is not installed")
        path = model_config.pop("path")
        default_config = dict(
            trust_remote_code=True,
            skip_tokenizer_init=True,
        )
        default_config.update(model_config)
        self.llm = sgl.Engine(model_path=path, **default_config)
        self.generate_config = generate_config
        self.tokenizer = tokenizer
        self.config = AutoConfig.from_pretrained(path)

    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
        out_tokens = []
        out_len = []
        for out in outputs:
            out_tokens.append(out["token_ids"])
            out_len.append(out["meta_info"]["completion_tokens"])
        max_len = max(out_len)
        input_len = input_ids.shape[-1]
        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
        for i in range(len(out_tokens)):
            out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
            attention_mask[i, input_len + out_len[i] :] = 0
        out = torch.tensor(out_tokens)
        out = torch.cat((input_ids, out), dim=1)
        labels = out.clone()
        labels[..., :input_len] = -100
        for i in range(len(out_len)):
            labels[i, input_len + out_len[i] :] = -100
        data = {
            "input_ids": out,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        data = {k: v.to(get_current_device()) for k, v in data.items()}
        return data

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        if self.config.tie_word_embeddings:
            del state_dict["lm_head.weight"]
        named_tensors = [(k, v) for k, v in state_dict.items()]
        self.llm.update_weights_from_tensor(named_tensors)


class VLLMInferenceBackend(BaseInferenceBackend):
    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        if LLM is None:
            raise ImportError("vllm is not installed")
        path = model_config.pop("path")
        default_config = dict(
            trust_remote_code=True,
            # skip_tokenizer_init=True,
        )
        default_config.update(model_config)
        self.llm = LLM(path, **default_config)
        self.generate_config = SamplingParams(**generate_config, stop_token_ids=[tokenizer.eos_token_id])
        self.tokenizer = tokenizer
        self.config = AutoConfig.from_pretrained(path)

    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        outputs = self.llm.generate(
            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
        )
        out_tokens = []
        out_len = []
        for out in outputs:
            out_tokens.append(list(out.outputs[0].token_ids))
            out_len.append(len(out.outputs[0].token_ids))
        max_len = max(out_len)
        input_len = input_ids.shape[-1]
        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
        for i in range(len(out_tokens)):
            out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
            attention_mask[i, input_len + out_len[i] :] = 0
        out = torch.tensor(out_tokens)
        out = torch.cat((input_ids, out), dim=1)
        labels = out.clone()
        labels[..., :input_len] = -100
        for i in range(len(out_len)):
            labels[i, input_len + out_len[i] :] = -100
        data = {
            "input_ids": out,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        data = {k: v.to(get_current_device()) for k, v in data.items()}
        return data

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())


BACKEND_MAP = {
    "transformers": TransformersInferenceBackend,
    "sglang": SGLangInferenceBackend,
    "vllm": VLLMInferenceBackend,
}
@@ -0,0 +1,87 @@
from typing import Any, Dict, Optional

import ray

from .consumer import SimpleConsumer
from .producer import SimpleProducer


def get_jsonl_size_fast(path: str) -> int:
    with open(path) as f:
        lines = f.readlines()
        lines = [line for line in lines if line.strip()]
        return len(lines) - 1


def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    ep_size = plugin_config.get("ep_size", 1)
    sp_size = plugin_config.get("sp_size", 1)
    return n_procs // (tp_size * pp_size * ep_size * sp_size)


def launch_distributed(
    num_producers: int,
    num_proc_per_producer: int,
    num_consumer_procs: int,
    num_episodes: int,
    inference_batch_size: int,
    inference_microbatch_size: int,
    train_batch_size: int,
    train_microbatch_size: int,
    dataset_config: Dict[str, Any],
    dataloaders_config: Dict[str, Any],
    inference_model_config: Dict[str, Any],
    generate_config: Dict[str, Any],
    train_model_config: Dict[str, Any],
    plugin_config: Dict[str, Any],
    tokenizer_config: Optional[Dict[str, Any]] = None,
    inference_backend: str = "transformers",
    master_addr: str = "localhost",
    master_port: int = 29500,
):
    train_dp_size = get_dp_size_fast(num_producers, plugin_config)
    assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

    dataset_path = dataset_config["path"]
    num_samples = get_jsonl_size_fast(dataset_path)
    global_inference_batch_size = inference_batch_size * num_producers
    num_update_per_episode = num_samples // global_inference_batch_size
    num_recv_per_update = inference_batch_size // inference_microbatch_size

    procs = []
    for i in range(num_producers):
        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
            producer_idx=i,
            num_producers=num_producers,
            num_consumer_procs=num_consumer_procs,
            num_episodes=num_episodes,
            batch_size=inference_batch_size,
            dataset_config=dataset_config,
            dataloaders_config=dataloaders_config,
            model_config=inference_model_config,
            generate_config=generate_config,
            tokenizer_config=tokenizer_config,
            microbatch_size=inference_microbatch_size,
            backend=inference_backend,
        )
        procs.append(producer)
    for i in range(num_consumer_procs):
        consumer = SimpleConsumer.options(num_gpus=1).remote(
            num_producers=num_producers,
            num_episodes=num_episodes,
            rank=i,
            world_size=num_consumer_procs,
            master_addr=master_addr,
            master_port=master_port,
            num_update_per_episode=num_update_per_episode,
            num_recv_per_update=num_recv_per_update,
            batch_size=train_batch_size,
            model_config=train_model_config,
            plugin_config=plugin_config,
            microbatch_size=train_microbatch_size,
        )
        procs.append(consumer)
    ray.get([p.setup.remote() for p in procs])
    ray.get([p.loop.remote() for p in procs])
@@ -0,0 +1,160 @@
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
from coati.dataset.loader import RawConversationDataset
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send


class BaseProducer:
    def __init__(
        self,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        dataset_config: Dict[str, Any],
        dataloaders_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
    ):
        self.producer_idx = producer_idx
        self.num_producers = num_producers
        self.num_consumer_procs = num_consumer_procs
        self.num_episodes = num_episodes
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0
        self.num_microbatches = batch_size // microbatch_size

        self.dataset_config = dataset_config
        self.model_config = model_config
        self.generate_config = generate_config
        self.tokenizer_config = tokenizer_config

        # init tokenizer
        if tokenizer_config is None:
            tokenizer_path = model_config["path"]
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        else:
            tokenizer_path = tokenizer_config.pop("path")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
        self.tokenizer.padding_side = "left"

        # init dataloader
        dataset_path = dataset_config.pop("path")
        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=microbatch_size,
            sampler=DistributedSampler(
                self.dataset,
                num_replicas=num_producers,
                rank=producer_idx,
                shuffle=True,
                drop_last=True,
                seed=42,
            ),
            num_workers=4,
        )
        self.device = get_current_device()

        # init backend
        if backend in BACKEND_MAP:
            self.backend_cls = BACKEND_MAP[backend]
        else:
            raise ValueError(f"Unexpected backend {backend}")

    def setup(self) -> None:
        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")  # this producer is rank 0 of its own data group
        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")  # all producers plus consumer rank 0 share one weight-sync group

    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        raise NotImplementedError

    def loop(self) -> None:
        num_update_per_episode = len(self.dataloader) // self.num_microbatches
        num_valid_microbatches = num_update_per_episode * self.num_microbatches

        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
        )
        for episode in range(self.num_episodes):
            self.dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.dataloader):
                if i >= num_valid_microbatches:
                    break
                outputs = self.rollout(**batch)
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs = pre_send(outputs)
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )

                if (i + 1) % self.num_microbatches == 0 and (
                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                ):
                    # don't sync model for last iteration
                    print(
                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                    )
                    state_dict = ray_broadcast_tensor_dict(
                        None, self.num_producers, device=self.device, group_name="sync_model"
                    )
                    self.load_state_dict(state_dict)


@ray.remote
class SimpleProducer(BaseProducer):
    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        dataset_config,
        dataloaders_config,
        model_config,
        generate_config,
        tokenizer_config=None,
        microbatch_size=1,
        backend="transformers",
    ):
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            dataset_config,
            dataloaders_config,
            model_config,
            generate_config,
            tokenizer_config,
            microbatch_size,
            backend,
        )
        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)

    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        return self.model.generate(input_ids, attention_mask, **kwargs)

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
@@ -0,0 +1,40 @@
from typing import Dict, List

import torch


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
    batches = []
    for k, v in batch.items():
        if len(batches) == 0:
            unbinded_tensors = v.unbind(0)
            batches = [{k: tensor} for tensor in unbinded_tensors]
        else:
            unbinded_tensors = v.unbind(0)
            assert len(batches) == len(unbinded_tensors)
            for i, tensor in enumerate(unbinded_tensors):
                batches[i][k] = tensor
    return batches


def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    batch = {}
    for k in batches[0].keys():
        batch[k] = torch.stack([b[k] for b in batches], dim=0)
    return batch


def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # compress attention_mask to save bandwidth
    if "attention_mask" in batch:
        attention_mask = batch["attention_mask"]
        batch["attention_mask"] = attention_mask.to(torch.bool)
    return batch


def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # decompress attention_mask
    if "attention_mask" in batch:
        attention_mask = batch["attention_mask"]
        batch["attention_mask"] = attention_mask.to(torch.int)
    return batch
@@ -0,0 +1,94 @@
import argparse

import ray
import torch
from coati.distributed.launch import launch_distributed

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
    parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
    parser.add_argument("-t", "--num-trainers", type=int, default=2)
    parser.add_argument("-i", "--num-inferencer", type=int, default=2)
    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
    parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
    parser.add_argument("-b", "--backend", type=str, default="transformers")
    args = parser.parse_args()

    ray.init(address="local", namespace="ray-example")

    inference_model_config = dict(path=args.model)
    train_model_config = dict(path=args.model)
    generate_config = dict(
        top_k=50,
        top_p=0.8,
    )

    if args.backend == "transformers":
        inference_model_config.update(
            dict(
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
            )
        )
        train_model_config.update(
            dict(
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                use_cache=False,
            )
        )
        generate_config.update(
            dict(
                max_length=512,
                do_sample=True,
                max_new_tokens=None,
                early_stopping=False,
            )
        )
    elif args.backend == "vllm":
        inference_model_config.update(
            dict(
                gpu_memory_utilization=0.6,
            )
        )
        generate_config.update(
            dict(
                max_tokens=256,
                ignore_eos=True,
            )
        )
    else:  # sglang backend
        inference_model_config.update(
            dict(
                mem_fraction_static=0.6,
            )
        )
        generate_config.update(
            dict(
                max_new_tokens=256,
                ignore_eos=True,
            )
        )

    launch_distributed(
        num_producers=args.num_inferencer,
        num_proc_per_producer=1,
        num_consumer_procs=args.num_trainers,
        num_episodes=1,
        inference_batch_size=args.inference_batch_size,
        inference_microbatch_size=args.inference_microbatch_size,
        train_batch_size=args.train_batch_size,
        train_microbatch_size=args.train_microbatch_size,
        dataset_config={"path": args.dataset, "max_length": 256},
        dataloaders_config={},
        inference_model_config=inference_model_config,
        generate_config=generate_config,
        train_model_config=train_model_config,
        plugin_config={},
        inference_backend=args.backend,
        master_addr="localhost",
        master_port=29504,
    )
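Assuming this entry script is saved as `rl_example.py` (the file name is not shown in this diff), a single-node run with two trainer GPUs and two inferencer GPUs, matching the argument defaults above, might look like:

```bash
# 4 GPUs total: 2 Ray producer actors (vLLM inference) + 2 consumer actors (training)
python rl_example.py -m Qwen/Qwen2.5-7B -d data.jsonl -t 2 -i 2 -b vllm \
    -ibs 64 -imbs 16 -tbs 16 -tmbs 2
```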