[chat] add distributed impl (#6210)

feature/ray-rlhf
Hongxin Liu 2025-02-21 15:24:23 +08:00 committed by GitHub
parent 9379cbd668
commit 43c9b5fb44
9 changed files with 798 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
# Requirements
```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
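`ray.util.collective` uses cupy's NCCL bindings for the GPU broadcasts in this module, so a quick import check catches a broken install early. A minimal sanity check, assuming the packages above installed cleanly:

```python
# Sanity check: the NCCL library installed via cupyx.tools must be importable.
import cupy
from cupy.cuda import nccl

print("CUDA runtime version:", cupy.cuda.runtime.runtimeGetVersion())
print("NCCL version:", nccl.get_version())
```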

View File

@@ -0,0 +1,57 @@
from typing import Any, Dict
import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version
def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
rank = cc.get_rank(group_name)
if rank == src:
if Version(torch.__version__) >= Version("2.3.0"):
obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
elif Version(torch.__version__) >= Version("1.13.0"):
obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
else:
obj_tensor, size_tensor = c10d._object_to_tensor(obj)
obj_tensor = obj_tensor.to(device)
size_tensor = size_tensor.to(device)
else:
size_tensor = torch.empty(1, dtype=torch.int64, device=device)
cc.broadcast(size_tensor, src, group_name)
if rank != src:
obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
cc.broadcast(obj_tensor, src, group_name)
if rank != src:
if Version(torch.__version__) >= Version("2.3.0"):
obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
else:
obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
return obj
def ray_broadcast_tensor_dict(
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
rank = cc.get_rank(group_name)
if rank == src:
metadata = []
for k, v in tensor_dict.items():
metadata.append((k, v.shape, v.dtype))
else:
metadata = None
metadata = ray_broadcast_object(metadata, src, device, group_name)
if rank != src:
out_dict = {}
for k, shape, dtype in metadata:
if rank == src:
tensor = tensor_dict[k]
else:
tensor = torch.empty(shape, dtype=dtype, device=device)
cc.broadcast(tensor, src, group_name)
if rank != src:
out_dict[k] = tensor
if rank == src:
out_dict = tensor_dict
return out_dict
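A minimal sketch of how these helpers are intended to be driven from Ray actors. The actor class, group name, module path, and two-GPU layout are assumptions for illustration; the real wiring lives in the producer/consumer classes below.

```python
import ray
import ray.util.collective as cc
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict  # assumed module path

@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

    def setup(self):
        # Every participant joins the same named collective group.
        cc.init_collective_group(self.world_size, self.rank, backend="nccl", group_name="demo")

    def run(self):
        # Only the source rank passes real tensors; receivers pass None and get copies back.
        tensors = {"w": torch.ones(2, 2, device="cuda")} if self.rank == 0 else None
        out = ray_broadcast_tensor_dict(tensors, src=0, device="cuda", group_name="demo")
        return {k: tuple(v.shape) for k, v in out.items()}

ray.init()
workers = [Worker.remote(r, 2) for r in range(2)]
ray.get([w.setup.remote() for w in workers])
print(ray.get([w.run.remote() for w in workers]))  # both ranks report {'w': (2, 2)}
```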

View File

@@ -0,0 +1,190 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional
import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
class BaseConsumer:
def __init__(
self,
num_producers: int,
num_episodes: int,
rank: int,
world_size: int,
master_addr: str,
master_port: int,
num_update_per_episode: int,
num_recv_per_update: int,
batch_size: int,
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.num_update_per_episode = num_update_per_episode
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size
self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device()
def setup(self) -> None:
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(
tp_size=1,
pp_size=1,
precision="bf16",
zero_stage=1,
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.buffer = []
self.recv_cnt = 0
def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def step(self, step_idx: int, **kwargs) -> Optional[float]:
raise NotImplementedError
def loop(self) -> None:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
for episode in range(self.num_episodes):
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
for step in pbar:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
unbind_batch(
ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
)
)
while len(self.buffer) >= self.dp_size * self.microbatch_size:
batches = self.buffer[
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, **batch)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
state_dict = self.state_dict()
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
@ray.remote
class SimpleConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.model.train()
self.model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
self.accum_loss = torch.zeros(1, device=self.device)
def setup(self):
super().setup()
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
need_update = (step_idx + 1) % self.num_microbatches == 0
ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
with ctx:
out = self.model(**kwargs)
loss = out.loss / self.num_microbatches
self.accum_loss.add_(loss.data)
self.booster.backward(loss, self.optimizer)
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
self.accum_loss.zero_()
return loss_scalar
def state_dict(self):
self.model._force_wait_all_gather()
model = self.model.unwrap()
state_dict = model.state_dict()
return state_dict
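The receive loop in `BaseConsumer.loop` shards the shared buffer across data-parallel ranks purely by index before handing each slice to `step`. A tiny standalone illustration of that slicing (sizes chosen arbitrarily):

```python
# Standalone illustration of the buffer sharding in BaseConsumer.loop.
dp_size, microbatch_size = 2, 2
buffer = list(range(8))  # stand-ins for the per-sample dicts produced by unbind_batch

while len(buffer) >= dp_size * microbatch_size:
    for dp_rank in range(dp_size):  # in the real loop each rank only takes its own slice
        batch = buffer[dp_rank * microbatch_size : (dp_rank + 1) * microbatch_size]
        print(f"dp_rank {dp_rank} trains on {batch}")
    buffer = buffer[dp_size * microbatch_size :]
# dp_rank 0 -> [0, 1], dp_rank 1 -> [2, 3]; next pass: [4, 5] and [6, 7]
```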

View File

@@ -0,0 +1,164 @@
from typing import Any, Dict
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
from colossalai.utils import get_current_device
try:
import sglang as sgl
except ImportError:
sgl = None
try:
from vllm import LLM, SamplingParams
except ImportError:
LLM = None
class BaseInferenceBackend:
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
pass
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
pass
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
pass
class TransformersInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
path = model_config.pop("path")
default_config = dict(
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
default_config.update(model_config)
self.model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(path, **default_config)
self.generate_config = generate_config
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
input_len = input_ids.shape[-1]
labels = out.clone()
labels[..., :input_len] = -100
attention_mask = F.pad(attention_mask, (0, out.shape[-1] - input_len), value=1)
attention_mask = attention_mask.expand_as(labels)
data = {
"input_ids": out,
"attention_mask": attention_mask,
"labels": labels,
}
return data
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.model.load_state_dict(state_dict)
class SGLangInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
default_config = dict(
trust_remote_code=True,
skip_tokenizer_init=True,
)
default_config.update(model_config)
self.llm = sgl.Engine(model_path=path, **default_config)
self.generate_config = generate_config
self.tokenizer = tokenizer
self.config = AutoConfig.from_pretrained(path)
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
out_tokens = []
out_len = []
for out in outputs:
out_tokens.append(out["token_ids"])
out_len.append(out["meta_info"]["completion_tokens"])
max_len = max(out_len)
input_len = input_ids.shape[-1]
attention_mask = F.pad(attention_mask, (0, max_len), value=1)
for i in range(len(out_tokens)):
out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
attention_mask[i, input_len + out_len[i] :] = 0
out = torch.tensor(out_tokens)
out = torch.cat((input_ids, out), dim=1)
labels = out.clone()
labels[..., :input_len] = -100
for i in range(len(out_len)):
labels[i, input_len + out_len[i] :] = -100
data = {
"input_ids": out,
"attention_mask": attention_mask,
"labels": labels,
}
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
if self.config.tie_word_embeddings:
del state_dict["lm_head.weight"]
named_tensors = [(k, v) for k, v in state_dict.items()]
self.llm.update_weights_from_tensor(named_tensors)
class VLLMInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
if LLM is None:
raise ImportError("vllm is not installed")
path = model_config.pop("path")
default_config = dict(
trust_remote_code=True,
# skip_tokenizer_init=True,
)
default_config.update(model_config)
self.llm = LLM(path, **default_config)
self.generate_config = SamplingParams(**generate_config, stop_token_ids=[tokenizer.eos_token_id])
self.tokenizer = tokenizer
self.config = AutoConfig.from_pretrained(path)
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
outputs = self.llm.generate(
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
)
out_tokens = []
out_len = []
for out in outputs:
out_tokens.append(list(out.outputs[0].token_ids))
out_len.append(len(out.outputs[0].token_ids))
max_len = max(out_len)
input_len = input_ids.shape[-1]
attention_mask = F.pad(attention_mask, (0, max_len), value=1)
for i in range(len(out_tokens)):
out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
attention_mask[i, input_len + out_len[i] :] = 0
out = torch.tensor(out_tokens)
out = torch.cat((input_ids, out), dim=1)
labels = out.clone()
labels[..., :input_len] = -100
for i in range(len(out_len)):
labels[i, input_len + out_len[i] :] = -100
data = {
"input_ids": out,
"attention_mask": attention_mask,
"labels": labels,
}
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())
BACKEND_MAP = {
"transformers": TransformersInferenceBackend,
"sglang": SGLangInferenceBackend,
"vllm": VLLMInferenceBackend,
}
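A sketch of how one of these backends can be selected and driven standalone. The model path and prompt are placeholders; `BaseProducer` below does the same thing when it instantiates `self.backend_cls(...)` and calls `rollout`.

```python
import torch
from transformers import AutoTokenizer

# Hypothetical standalone use of the transformers backend.
model_config = {"path": "Qwen/Qwen2.5-0.5B-Instruct"}  # placeholder checkpoint
generate_config = {"max_new_tokens": 32, "do_sample": False}
tokenizer = AutoTokenizer.from_pretrained(model_config["path"])

backend = BACKEND_MAP["transformers"](dict(model_config), generate_config, tokenizer)
enc = tokenizer(["Explain NCCL in one sentence."], return_tensors="pt")
data = backend.generate(enc["input_ids"], enc["attention_mask"])
print({k: tuple(v.shape) for k, v in data.items()})  # input_ids / attention_mask / labels
```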

View File

@@ -0,0 +1,87 @@
from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
from .producer import SimpleProducer
def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
lines = f.readlines()
lines = [line for line in lines if line.strip()]
return len(lines) - 1
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
tp_size = plugin_config.get("tp_size", 1)
pp_size = plugin_config.get("pp_size", 1)
ep_size = plugin_config.get("ep_size", 1)
sp_size = plugin_config.get("sp_size", 1)
return n_procs // (tp_size * pp_size * ep_size * sp_size)
def launch_distributed(
num_producers: int,
num_proc_per_producer: int,
num_consumer_procs: int,
num_episodes: int,
inference_batch_size: int,
inference_microbatch_size: int,
train_batch_size: int,
train_microbatch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
master_addr: str = "localhost",
master_port: int = 29500,
):
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
procs = []
for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=inference_batch_size,
dataset_config=dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
)
procs.append(producer)
for i in range(num_consumer_procs):
consumer = SimpleConsumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
world_size=num_consumer_procs,
master_addr=master_addr,
master_port=master_port,
num_update_per_episode=num_update_per_episode,
num_recv_per_update=num_recv_per_update,
batch_size=train_batch_size,
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
ray.get([p.loop.remote() for p in procs])
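The bookkeeping above decides how many optimizer-visible updates happen per episode and how many receives each update takes. A worked pass with the default sizes from the example script at the end of this commit (the dataset size is assumed):

```python
# Worked example of the scheduling arithmetic in launch_distributed.
num_producers, inference_batch_size, inference_microbatch_size = 2, 64, 16
num_samples = 1024  # assumed get_jsonl_size_fast(dataset_path) result

global_inference_batch_size = inference_batch_size * num_producers       # 128 rollouts per update
num_update_per_episode = num_samples // global_inference_batch_size      # 8 updates per episode
num_recv_per_update = inference_batch_size // inference_microbatch_size  # 4 receives per update
print(num_update_per_episode, num_recv_per_update)  # 8 4
```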

View File

@@ -0,0 +1,160 @@
from typing import Any, Dict, Optional
import ray
import ray.util.collective as cc
import torch
from coati.dataset.loader import RawConversationDataset
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send
class BaseProducer:
def __init__(
self,
producer_idx: int,
num_producers: int,
num_consumer_procs: int,
num_episodes: int,
batch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
microbatch_size: int = 1,
backend: str = "transformers",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
self.num_consumer_procs = num_consumer_procs
self.num_episodes = num_episodes
self.batch_size = batch_size
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
self.dataset_config = dataset_config
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
# init tokenizer
if tokenizer_config is None:
tokenizer_path = model_config["path"]
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
else:
tokenizer_path = tokenizer_config.pop("path")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"
# init dataloader
dataset_path = dataset_config.pop("path")
self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
self.dataloader = DataLoader(
self.dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
self.dataset,
num_replicas=num_producers,
rank=producer_idx,
shuffle=True,
drop_last=True,
seed=42,
),
num_workers=4,
)
self.device = get_current_device()
# init backend
if backend in BACKEND_MAP:
self.backend_cls = BACKEND_MAP[backend]
else:
raise ValueError(f"Unexpected backend {backend}")
def setup(self) -> None:
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
raise NotImplementedError
def loop(self) -> None:
num_update_per_episode = len(self.dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
)
for episode in range(self.num_episodes):
self.dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.dataloader):
if i >= num_valid_microbatches:
break
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
if (i + 1) % self.num_microbatches == 0 and (
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
):
# don't sync model for last iteration
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
@ray.remote
class SimpleProducer(BaseProducer):
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
dataset_config,
dataloaders_config,
model_config,
generate_config,
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
):
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
dataset_config,
dataloaders_config,
model_config,
generate_config,
tokenizer_config,
microbatch_size,
backend,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
return self.model.generate(input_ids, attention_mask, **kwargs)
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
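Each producer is rank 0 of its own `sync_data_<idx>` group, with the consumers at ranks 1..num_consumer_procs, while the shared `sync_model` group places producers at ranks 0..num_producers-1 and consumer rank 0 at index `num_producers`. A small print-out of that layout for 2 producers and 2 consumers (an illustration of the `init_collective_group` calls above, not new wiring):

```python
# Group layout implied by BaseProducer.setup and BaseConsumer.setup.
num_producers, num_consumer_procs = 2, 2

for p in range(num_producers):
    ranks = {f"producer{p}": 0}
    ranks.update({f"consumer{c}": c + 1 for c in range(num_consumer_procs)})
    print(f"sync_data_{p}: world_size={1 + num_consumer_procs}, ranks={ranks}")

sync_model = {f"producer{p}": p for p in range(num_producers)}
sync_model["consumer0"] = num_producers  # only consumer rank 0 joins sync_model
print(f"sync_model: world_size={num_producers + 1}, ranks={sync_model}")
```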

View File

@@ -0,0 +1,40 @@
from typing import Dict, List
import torch
def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
for k, v in batch.items():
if len(batches) == 0:
unbinded_tensors = v.unbind(0)
batches = [{k: tensor} for tensor in unbinded_tensors]
else:
unbinded_tensors = v.unbind(0)
assert len(batches) == len(unbinded_tensors)
for i, tensor in enumerate(unbinded_tensors):
batches[i][k] = tensor
return batches
def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
batch = {}
for k in batches[0].keys():
batch[k] = torch.stack([batch[k] for batch in batches], dim=0)
return batch
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress attention_mask to save bandwidth
if "attention_mask" in batch:
attention_mask = batch["attention_mask"]
batch["attention_mask"] = attention_mask.to(torch.bool)
return batch
def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# decompress attention_mask
if "attention_mask" in batch:
attention_mask = batch["attention_mask"]
batch["attention_mask"] = attention_mask.to(torch.int)
return batch
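A tiny round trip through these helpers (shapes are arbitrary):

```python
import torch

batch = {
    "input_ids": torch.arange(12).reshape(4, 3),
    "attention_mask": torch.ones(4, 3, dtype=torch.int),
}
samples = unbind_batch(batch)                                          # 4 per-sample dicts
wire = post_recv(pre_send({k: v.clone() for k, v in batch.items()}))   # bool on the wire, int again after recv
rebuilt = bind_batch(samples)                                          # stacked back to [4, 3]
assert rebuilt["input_ids"].shape == (4, 3) and wire["attention_mask"].dtype == torch.int
```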

View File

@@ -0,0 +1,94 @@
import argparse
import ray
import torch
from coati.distributed.launch import launch_distributed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
parser.add_argument("-b", "--backend", type=str, default="transformers")
args = parser.parse_args()
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model)
generate_config = dict(
top_k=50,
top_p=0.8,
)
if args.backend == "transformers":
inference_model_config.update(
dict(
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_cache=False,
)
)
generate_config.update(
dict(
max_length=512,
do_sample=True,
max_new_tokens=None,
early_stopping=False,
)
)
elif args.backend == "vllm":
inference_model_config.update(
dict(
gpu_memory_utilization=0.6,
)
)
generate_config.update(
dict(
max_tokens=256,
ignore_eos=True,
)
)
else:
inference_model_config.update(
dict(
mem_fraction_static=0.6,
)
)
generate_config.update(
dict(
max_new_tokens=256,
ignore_eos=True,
)
)
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=1,
num_consumer_procs=args.num_trainers,
num_episodes=1,
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_microbatch_size=args.train_microbatch_size,
dataset_config={"path": args.dataset, "max_length": 256},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
train_model_config=train_model_config,
plugin_config={},
inference_backend=args.backend,
master_addr="localhost",
master_port=29504,
)
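With the defaults above, Ray is asked for one GPU per inferencer and one per trainer. A quick check of the resource math and the divisibility constraint that `launch_distributed` asserts (illustrative only):

```python
# Resource and batch-size sanity check for the default arguments above.
num_inferencer, num_proc_per_producer, num_trainers = 2, 1, 2
inference_batch_size, train_batch_size = 64, 16

gpus_requested = num_inferencer * num_proc_per_producer + num_trainers
print("GPUs requested from Ray:", gpus_requested)  # 4

# Mirrors the assert in launch_distributed (plugin_config={} so dp size falls back to the proc count).
train_dp_size = num_inferencer  # get_dp_size_fast(num_producers, {}) as written above
assert (inference_batch_size * num_inferencer) % (train_batch_size * train_dp_size) == 0
```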