# ColossalAI/applications/ColossalChat/coati/distributed/inference_backend.py
#
# 245 lines
# 10 KiB
# Python

from typing import Any, Dict
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
from colossalai.utils import get_current_device
from .utils import log_probs_from_logits, update_by_default
try:
import sglang as sgl
except ImportError:
sgl = None
try:
from vllm import LLM, SamplingParams
except ImportError:
LLM = None
class BaseInferenceBackend:
    """Interface shared by the inference backends defined in this module.

    Every method here is a no-op; the concrete backends below override them.
    """

    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        """Accept the backend configuration. The base class stores none of it."""

    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Produce a continuation for every prompt in the batch.

        Args:
            input_ids (torch.Tensor): prompt token ids, shape [B, S]
            attention_mask (torch.Tensor): prompt attention mask, shape [B, S]

        Returns:
            Dict[str, torch.Tensor]: a mapping that holds
                - input_ids (torch.Tensor): prompt followed by generated tokens, shape [B, S+N]
                - attention_mask (torch.Tensor): mask for the full sequence, shape [B, S+N]
                - action_log_probs (torch.Tensor): log-probabilities of the generated tokens, shape [B, N]
                - action_mask (torch.Tensor): mask over the generated tokens, shape [B, N]
            with N the number of generated tokens. All tensors are expected to live on CUDA.

        The base implementation does nothing and implicitly returns ``None``.
        """

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Replace the backend's model weights with ``state_dict``. No-op in the base class."""
class TransformersInferenceBackend(BaseInferenceBackend):
    """Inference backend that samples responses with a Hugging Face ``transformers`` causal LM."""

    # Fallback values handed to ``update_by_default`` together with the caller's model config.
    DEFAULT_MODEL_CONFIG = dict(
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    # Values written over the model config unconditionally.
    FORCE_MODEL_CONFIG = dict(
        device_map="auto",
    )
    # ``generate`` reads ``out.logits`` and ``out.sequences``, so these options are always forced.
    FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        """Load the model and prepare the keyword arguments used for every ``generate`` call.

        Args:
            model_config: must contain ``"path"`` (passed positionally to
                ``AutoModelForCausalLM.from_pretrained``); every other entry is forwarded as a keyword argument.
            generate_config: generation keyword arguments. A copy is stored, extended with
                ``FORCE_GENERATE_CONFIG`` and with the tokenizer under the ``"tokenizer"`` key.
            tokenizer: tokenizer whose ``eos_token_id`` is used to build the action mask.

        Raises:
            KeyError: if ``model_config`` has no ``"path"`` entry.
        """
        # Copy the merged config so that ``update`` / ``pop`` below can never touch the caller's dict,
        # whatever ``update_by_default`` returns.
        model_config = dict(update_by_default(model_config, self.DEFAULT_MODEL_CONFIG))
        model_config.update(self.FORCE_MODEL_CONFIG)
        path = model_config.pop("path")
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.generate_config = generate_config.copy()
        self.generate_config.update(self.FORCE_GENERATE_CONFIG)
        self.generate_config["tokenizer"] = tokenizer
        self.tokenizer = tokenizer

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Generate continuations and return them with log-probabilities and masks.

        Args:
            input_ids: prompt token ids, shape [B, S].
            attention_mask: prompt attention mask, shape [B, S].
            **kwargs: forwarded to ``self.model.generate`` together with ``self.generate_config``.
                A key present in both raises ``TypeError`` (duplicate keyword argument).

        Returns:
            Dict with, for B' returned sequences (a multiple of B) and N generated steps:
                - input_ids: ``out.sequences``, prompt followed by generated tokens
                - attention_mask: prompt mask concatenated with ``action_mask``, shape [B', S+N]
                - action_log_probs: per-step results of ``log_probs_from_logits`` concatenated on dim 1
                  (presumably shape [B', N] -- confirm against the helper)
                - action_mask: 1 up to and including the first EOS token of each row, 0 afterwards, shape [B', N]
                - response_idx: int tensor of shape [B', 2] holding the first and last unmasked response
                  position inside ``input_ids``
        """
        device = get_current_device()
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
        input_len = input_ids.shape[-1]
        new_token_ids = out.sequences[:, input_len:]

        # get log probs: ``out.logits`` holds one tensor per generated step
        assert new_token_ids.shape[-1] == len(out.logits)
        action_log_probs = torch.cat(
            [
                log_probs_from_logits(step_logits[:, None, :], new_token_ids[:, step : step + 1])
                for step, step_logits in enumerate(out.logits)
            ],
            dim=1,
        )

        # get action mask
        eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is None:
            action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
        else:
            is_eos = new_token_ids == eos_token_id
            # Number of EOS tokens strictly before each position. A position stays active while that
            # count is zero, i.e. everything after the first EOS of a row is masked out. This is the
            # vectorised form of zeroing ``mask[row, col + 1:]`` for every EOS hit.
            eos_before = is_eos.cumsum(dim=1) - is_eos.to(torch.long)
            action_mask = (eos_before == 0).to(attention_mask.dtype)

        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int, device=device)
        response_idx[:, 0] = input_len
        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1

        # More sequences than prompts were returned: repeat each prompt mask so rows line up again.
        if attention_mask.size(0) != action_mask.size(0):
            assert action_mask.size(0) % attention_mask.size(0) == 0
            attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)

        attention_mask = torch.cat((attention_mask, action_mask), dim=1)
        return {
            "input_ids": out.sequences,
            "attention_mask": attention_mask,
            "action_log_probs": action_log_probs,
            "action_mask": action_mask,
            "response_idx": response_idx,
        }

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load ``state_dict`` into the wrapped ``transformers`` model."""
        self.model.load_state_dict(state_dict)
class SGLangInferenceBackend(BaseInferenceBackend):
    """Inference backend that generates with an ``sglang`` engine.

    Unlike the other backends in this module, ``generate`` returns ``labels`` and no
    ``action_log_probs`` / ``action_mask`` / ``response_idx``.
    """

    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        """Start the engine and load the model's ``AutoConfig``.

        Args:
            model_config: must contain ``"path"``; the remaining entries override the engine defaults
                (``trust_remote_code=True``, ``skip_tokenizer_init=True``) and are passed to ``sgl.Engine``.
            generate_config: passed unchanged as ``sampling_params`` on every ``generate`` call.
            tokenizer: tokenizer whose ``pad_token_id`` pads shorter generations.

        Raises:
            ImportError: if ``sglang`` could not be imported.
            KeyError: if ``model_config`` has no ``"path"`` entry.
        """
        if sgl is None:
            raise ImportError("sglang is not installed")
        # Work on a copy: popping "path" from the caller's dict would make the same config
        # unusable for constructing a second backend.
        model_config = dict(model_config)
        path = model_config.pop("path")
        engine_config = dict(
            trust_remote_code=True,
            skip_tokenizer_init=True,
        )
        engine_config.update(model_config)
        self.llm = sgl.Engine(model_path=path, **engine_config)
        self.generate_config = generate_config
        self.tokenizer = tokenizer
        self.config = AutoConfig.from_pretrained(path)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Generate one continuation per prompt and pad them to a common length.

        Args:
            input_ids: prompt token ids, shape [B, S].
            attention_mask: prompt attention mask, shape [B, S].
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            Dict with tensors moved to ``get_current_device()``; M is the longest generation:
                - input_ids: prompts followed by pad-extended generations, shape [B, S+M]
                - attention_mask: prompt mask extended with 1 over generated tokens and 0 over padding
                - labels: copy of ``input_ids`` with prompt and padding positions set to -100
        """
        outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
        out_tokens = [out["token_ids"] for out in outputs]
        out_len = [out["meta_info"]["completion_tokens"] for out in outputs]
        max_len = max(out_len)
        input_len = input_ids.shape[-1]
        pad_token_id = self.tokenizer.pad_token_id

        # ``F.pad`` returns a new tensor, so the in-place writes below do not touch the caller's mask.
        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
        for i, length in enumerate(out_len):
            out_tokens[i] = out_tokens[i] + [pad_token_id] * (max_len - length)
            attention_mask[i, input_len + length :] = 0

        # Build the generated tokens on the prompts' device so the concatenation cannot hit a device mismatch.
        out = torch.tensor(out_tokens, device=input_ids.device)
        out = torch.cat((input_ids, out), dim=1)

        labels = out.clone()
        labels[..., :input_len] = -100
        for i, length in enumerate(out_len):
            labels[i, input_len + length :] = -100

        data = {
            "input_ids": out,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        device = get_current_device()
        return {k: v.to(device) for k, v in data.items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Push ``state_dict`` to the engine via ``update_weights_from_tensor``.

        When the model config has ``tie_word_embeddings`` set, ``"lm_head.weight"`` is left out of the
        update. The caller's ``state_dict`` is not modified, and a missing ``"lm_head.weight"`` entry
        is tolerated.
        """
        skip_lm_head = bool(self.config.tie_word_embeddings)
        named_tensors = [
            (name, tensor)
            for name, tensor in state_dict.items()
            if not (skip_lm_head and name == "lm_head.weight")
        ]
        self.llm.update_weights_from_tensor(named_tensors)
class VLLMInferenceBackend(BaseInferenceBackend):
    """Inference backend that generates with a vLLM ``LLM`` engine."""

    # Fallback values handed to ``update_by_default`` together with the caller's model config.
    DEFAULT_MODEL_CONFIG = dict(
        trust_remote_code=True,
    )
    # Always written over the sampling config. ``generate`` looks up the log-probability of each
    # sampled token in ``output.logprobs``, so log-probs must be requested.
    FORCE_GENERATE_CONFIG = dict(
        logprobs=0,
    )

    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
        """Create the vLLM engine and the ``SamplingParams`` used for every ``generate`` call.

        Args:
            model_config: must contain ``"path"`` (first positional argument of ``LLM``); every other
                entry is forwarded as a keyword argument.
            generate_config: keyword arguments for ``SamplingParams``; ``FORCE_GENERATE_CONFIG`` is
                applied on top of a copy. ``"n"`` (samples per prompt) is optional and treated as 1
                when absent.
            tokenizer: tokenizer whose ``pad_token_id`` pads shorter generations.

        Raises:
            ImportError: if ``vllm`` could not be imported.
            KeyError: if ``model_config`` has no ``"path"`` entry.
        """
        if LLM is None:
            raise ImportError("vllm is not installed")
        # Copy the merged config so ``pop`` can never touch the caller's dict.
        model_config = dict(update_by_default(model_config, self.DEFAULT_MODEL_CONFIG))
        path = model_config.pop("path")
        self.llm = LLM(path, **model_config)
        generate_config = generate_config.copy()
        generate_config.update(self.FORCE_GENERATE_CONFIG)
        self.generate_config = SamplingParams(**generate_config)
        self.tokenizer = tokenizer
        # ``n`` is optional for ``SamplingParams``; fall back to a single sample per prompt
        # instead of raising ``KeyError`` when the caller did not set it.
        self.num_generations = generate_config.get("n", 1)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Generate ``num_generations`` continuations per prompt.

        Args:
            input_ids: prompt token ids, shape [B, S].
            attention_mask: prompt attention mask, shape [B, S].
            **kwargs: only ``gt_answer`` is read; when present it is repeated ``num_generations``
                times along dim 1 and returned under the same key.

        Returns:
            Dict with tensors moved to ``get_current_device()``; G is ``num_generations`` and N is the
            longest generation:
                - input_ids: prompts followed by pad-extended generations, shape [B, G, S+N]
                - attention_mask: prompt mask concatenated with ``action_mask``, shape [B, G, S+N]
                - action_log_probs: log-prob of each sampled token, 0.0 on padding, shape [B, G, N]
                - action_mask: 1 on generated tokens, 0 on padding, shape [B, G, N]
                - response_idx: first and last response position inside ``input_ids``, shape [B, G, 2]
                - gt_answer: only when passed in ``kwargs``
        """
        micro_batch_size = input_ids.size(0)
        response_start_idx = input_ids.size(1)
        outputs = self.llm.generate(
            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
        )

        out_tokens = []
        out_len = []
        log_probs = []
        response_idx = []
        for out in outputs:
            for output_i in out.outputs:
                token_ids = list(output_i.token_ids)
                out_len.append(len(token_ids))
                out_tokens.append(token_ids)
                response_idx.append((response_start_idx, response_start_idx + len(token_ids) - 1))
                assert len(output_i.logprobs) == len(token_ids)
                # Each ``logprobs`` entry maps token id -> object with ``.logprob``; pick the sampled token's.
                log_probs.append([step[token].logprob for step, token in zip(output_i.logprobs, token_ids)])

        # pad them
        max_len = max(out_len)
        pad_token_id = self.tokenizer.pad_token_id
        action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
        for i, length in enumerate(out_len):
            pad_len = max_len - length
            out_tokens[i] = out_tokens[i] + [pad_token_id] * pad_len
            log_probs[i] = log_probs[i] + [0.0] * pad_len
            action_mask[i, length:] = 0

        # Place the new tensors on the same devices as the tensors they are concatenated with below,
        # so non-CPU inputs do not trigger a device-mismatch error.
        out_tokens = torch.tensor(out_tokens, device=input_ids.device)
        action_mask = action_mask.to(attention_mask.device)
        log_probs = torch.tensor(log_probs)
        response_idx = torch.tensor(response_idx)

        # Several samples per prompt: repeat every prompt row so rows line up with the generations.
        if attention_mask.size(0) != action_mask.size(0):
            assert action_mask.size(0) % attention_mask.size(0) == 0
            num_returns = action_mask.size(0) // attention_mask.size(0)
            attention_mask = attention_mask.repeat_interleave(num_returns, dim=0)
            input_ids = input_ids.repeat_interleave(num_returns, dim=0)

        out_tokens = torch.cat((input_ids, out_tokens), dim=1)
        attention_mask = torch.cat((attention_mask, action_mask), dim=1)

        data = {
            "input_ids": out_tokens,
            "attention_mask": attention_mask,
            "action_log_probs": log_probs,
            "action_mask": action_mask,
            "response_idx": response_idx,
        }
        # Regroup the flat [B * G, L] rows as [B, G, L].
        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}

        if "gt_answer" in kwargs:
            # repeat gt_answer for each prompt.
            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)

        device = get_current_device()
        return {k: v.to(device) for k, v in data.items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load ``state_dict`` into the model held by the engine's driver worker.

        NOTE(review): this reaches through vLLM engine internals
        (``llm_engine.model_executor.driver_worker.model_runner.model``); the attribute path
        presumably depends on the installed vLLM version -- confirm when upgrading.
        """
        self.llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(state_dict.items())
# Registry that maps a backend name to the class implementing it.
# NOTE: "sglang" (SGLangInferenceBackend) is deliberately not registered, because the sglang
# backend makes the process hang for an unknown reason.
BACKEND_MAP = dict(
    transformers=TransformersInferenceBackend,
    vllm=VLLMInferenceBackend,
)