mirror of https://github.com/hpcaitech/ColossalAI
* support vllm
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* modify vllm and update readme
* run pre-commit
* remove duplicated lines and refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update param name
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refine code
* update readme
* refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(pull/5493/merge)
Camille Zhong, 2 months ago, committed by GitHub
19 changed files with 576 additions and 35 deletions
@@ -1,5 +1,6 @@
 from .base import BaseModel
 from .chatglm import ChatGLM2Model, ChatGLMModel
 from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+from .vllm import vLLMModel

-__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"]
@@ -0,0 +1,498 @@
import copy
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from torch.utils.data import DataLoader
from tqdm import tqdm
from vllm import LLM, SamplingParams

from colossalai.logging import DistributedLogger

from .huggingface import HuggingFaceModel

IGNORE_INDEX = -100
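# Note: -100 is the conventional ignore index of torch.nn.CrossEntropyLoss; IGNORE_INDEX is presumably
# kept here for parity with huggingface.py and is not referenced directly in this file.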


class vLLMModel(HuggingFaceModel):
    """
    Model wrapper around vLLM models.

    Args:
        path: The path to a vLLM model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer.
        tokenizer_kwargs: Keyword arguments for the tokenizer.
        model_kwargs: Keyword arguments for the model.
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
        quantization: The method used to quantize the model weights.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
        enforce_eager: Whether to enforce eager execution.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
        disable_custom_all_reduce: See ParallelConfig.
    """

    def __init__(
        self,
        path: str,
        model_max_length: int = 2048,
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: Dict = None,
        model_kwargs: Dict = None,
        prompt_template: Conversation = None,
        batch_size: int = 1,
        logger: DistributedLogger = None,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        gpu_memory_utilization: float = 0.5,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ):
        super().__init__(
            path=path,
            model_max_length=model_max_length,
            prompt_template=prompt_template,
            batch_size=batch_size,
            logger=logger,
        )

        self._load_model(
            path=path,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
            tokenizer_path=tokenizer_path if tokenizer_path else None,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            quantization=quantization,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
        )

    def _load_model(
        self,
        path: str,
        model_kwargs: dict,
        tokenizer_kwargs: dict,
        tokenizer_path: Optional[str] = None,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
    ):
        """
        Load model.

        Args:
            path: The path to the model.
            model_kwargs: Keyword arguments for the model.
            tokenizer_kwargs: Keyword arguments for the tokenizer.
            tokenizer_path: The path to the tokenizer.
            trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
            tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
            quantization: The method used to quantize the model weights.
            gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
            swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
            enforce_eager: Whether to enforce eager execution.
            max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            disable_custom_all_reduce: See ParallelConfig.

        """
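        # Assumes the eval config passes "torch_dtype" as a string such as "torch.float16"; it is
        # converted into an actual torch dtype here because vLLM's LLM expects a `dtype` argument.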
if "torch_dtype" in model_kwargs: |
||||
model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"]) |
||||
model_kwargs.pop("torch_dtype") |
||||
else: |
||||
model_kwargs.setdefault("dtype", torch.float16) |
||||
|
||||
if "trust_remote_code" in model_kwargs: |
||||
trust_remote_code = model_kwargs["trust_remote_code"] |
||||
model_kwargs.pop("trust_remote_code") |
||||
|
||||
if "trust_remote_code" in tokenizer_kwargs: |
||||
trust_remote_code = tokenizer_kwargs["trust_remote_code"] |
||||
tokenizer_kwargs.pop("trust_remote_code") |
||||
|
||||
self.model = LLM( |
||||
model=path, |
||||
trust_remote_code=trust_remote_code, |
||||
tensor_parallel_size=tensor_parallel_size, |
||||
quantization=quantization, |
||||
gpu_memory_utilization=gpu_memory_utilization, |
||||
swap_space=swap_space, |
||||
cpu_offload_gb=cpu_offload_gb, |
||||
enforce_eager=enforce_eager, |
||||
max_context_len_to_capture=max_context_len_to_capture, |
||||
max_seq_len_to_capture=max_seq_len_to_capture, |
||||
disable_custom_all_reduce=disable_custom_all_reduce, |
||||
**model_kwargs, |
||||
**tokenizer_kwargs, |
||||
) |
||||

        self.tokenizer = self.model.get_tokenizer()

        if self.batch_size > 1:
            self.tokenizer.padding_side = "left"
            self.tokenizer.truncation_side = "left"

        if self.tokenizer.pad_token_id is None:
            self.logger.warning("pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.")
            if self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif hasattr(self.tokenizer, "eod_id"):
                # Qwen has an eod token "<|endoftext|>".
                self.tokenizer.pad_token_id = self.tokenizer.eod_id
            else:
                self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
                raise ValueError(
                    "The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
                    "Please set pad_token_id manually."
                )

    def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]:
        """
        Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110

        Args:
            inputs: A batch of input strings.
            labels: A batch of labels.

        Returns:
            A list of losses and a list of label lengths.

        """
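        # With SamplingParams(logprobs=1), vLLM returns the logprob of each sampled token; the loss
        # below is the mean negative logprob over (at most) the first len(label) generated tokens.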
        batch_size = len(inputs)
        sampling_kwargs = SamplingParams(logprobs=1)
        outputs = self.model.generate(inputs, sampling_kwargs)
        ce_loss = []

        if labels is not None:
            lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels]
        else:
            lens = [1] * batch_size

        for i in range(batch_size):
            logprobs = outputs[i].outputs[0].logprobs
            token_ids = outputs[i].outputs[0].token_ids

            logprobs_list = [logprobs[t][token_ids[t]] for t in range(len(logprobs))]
            logprobs_list = [entry.logprob for entry in logprobs_list]
            logprobs_list = np.array(logprobs_list)

            if lens is not None:
                logprobs_list = logprobs_list[: lens[i]]

            loss = -logprobs_list.sum(axis=-1) / lens[i]
            ce_loss.append(loss)

        batch_loss = np.array(ce_loss)

        return batch_loss, lens

    def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
        """
        Infer the given data.
        This function calls self.generate() to get model outputs and uses a LogitsProcessor to extract the logits over the choice tokens.

        Args:
            data_loader: The data loader for inference.
            inference_kwargs: Arguments for inference.
            debug: Whether to display the generated prompt for debugging.

        Returns:
            Inference results.

        """
        calculate_loss = inference_kwargs["calculate_loss"]
        classes = inference_kwargs["all_classes"]
        language = inference_kwargs["language"]
        calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
        max_new_tokens = inference_kwargs["max_new_tokens"]
        few_shot_data = inference_kwargs.get("few_shot_data", None)

        # Some classification questions' options are full texts rather than a single letter such as A, B, C, or D.
        # If the option text length is greater than 1, we won't calculate loss over choices.
        if classes is not None and any(len(c) > 1 for c in classes):
            classes = None

        self.choices = classes
        self.indices_for_choices = None
        if self.choices:
            # Get indices for each choice
            self._get_choices_indices(language)

            self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

        bar = tqdm(
            range(len(data_loader)),
            desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
            disable=not is_rank_0(),
        )
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        answers = []

        for i, batch in enumerate(data_loader):
            batch_prompt, batch_target = get_batch_prompt(
                self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
            )

            if is_rank_0() and debug and i == 0:
                self.logger.info(
                    f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
                )
                self.logger.info("-" * 120)
                self.logger.info("An example prompt and prompt with target is:")
                self.logger.info("-" * 120)
                self.logger.info(batch_prompt[0])
                self.logger.info("-" * 120)
                self.logger.info(batch_prompt[0] + batch_target[0][0])

            if not calculate_overall_loss:
                batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)

            if calculate_loss:
                batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
                    batch_prompt, batch_target, calculate_overall_loss
                )

            probs = []
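            # scores comes from GetTokenLogitsProcessor: one row of first-new-token logits over the
            # choice tokens per sample; it is stored as logits_over_choices below.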
            if self.indices_for_choices:
                scores = scores.to(torch.float32)
                # If we have indices_for_choices (this must be a single-choice question), there will be only one target answer for one data sample.
                # Otherwise this would violate the single-choice setting.

                if calculate_loss:
                    labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]

                    loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()

                probs = scores.numpy().tolist()
                probs = [
                    {choice: probs[k][self.str_label_map[choice]] for choice in self.choices} for k in range(len(probs))
                ]

            for j in range(len(batch)):
                if not calculate_overall_loss:
                    if isinstance(batch[j]["output"], list):
                        batch[j]["output"].append(batch_decodes[j].strip())
                    else:
                        batch[j]["output"] = batch_decodes[j].strip()

                    if isinstance(scores, torch.Tensor):
                        batch[j]["logits_over_choices"] = probs[j]

                        if calculate_loss:
                            batch[j]["loss_over_choices"] = loss_over_choices[j]

                if calculate_loss:
                    batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()

                    # loss_sum is specially used for pretrain datasets to calculate per-byte perplexity.
                    # However, loss (which is the per-sample loss) suffices for most cases.
                    batch[j]["loss_sum"] = batch_losses[j]
                    batch[j]["token_num"] = batch_target_token_nums[j]

                    if batch_bytes_nums:
                        batch[j]["byte_num"] = batch_bytes_nums[j]
            answers.extend(batch)

            bar.update()

        return answers

    @torch.no_grad()
    def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
        """Generate results given a list of inputs and get logits of the first new token over choices.

        Args:
            inputs: A list of strings.
            max_new_tokens: Max new tokens for generation.
            kwargs: Keyword arguments for generation.

        Returns:
            A list of generated strings and logits over choices.

        Note:
            Currently the function only returns the logits of the first new token.
            It is used for single-choice questions.
            For multiple-choice questions, please avoid using the loss over choices.
            You should set the argument choices as None in self.inference().

        """
        truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)

        generation_kwargs = kwargs.copy()
        generation_kwargs.update({"max_tokens": max_new_tokens})
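        # vLLM invokes the logits processor at every decoding step; the processor only records the
        # logits of the candidate choice tokens at the first generation step of each sequence.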
        logits_processor = GetTokenLogitsProcessor(self.indices_for_choices)

        sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs)

        outputs = self.model.generate(truncated_inputs, sampling_kwargs)
        output_strs = []
        for output in outputs:
            generated_text = output.outputs[0].text
            output_strs.append(generated_text)
        scores = logits_processor.get_target_logits()

        return output_strs, scores

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
    ) -> List[List[float]]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompts without target answers.
            batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
            calculate_overall_loss: Whether to calculate loss over the whole sequence (prompt plus target), used for pretrain datasets.

        Returns:
            Loss.

        """

        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
        # We don't need to generate new tokens.
        # The target answer's length is usually much smaller than model_max_length, but we still truncate it just in case.
        # We don't call self._get_truncated_prompts for batch_prompt here because we first need the target answer's length to reserve space for its tokens.
        if not calculate_overall_loss:
            batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

        # Get the number of target answers for different questions
        batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

        if calculate_overall_loss:
            batch = []
            bytes_list = []
            batch_prompt_pretrain = []
            for p, b in zip(batch_prompt, batch_target):
                batch.append(p + b[0])

            for input in batch:
                # Pretrain data tends to be very long, sometimes much longer than model_max_length, so we first tokenize only 1/ratio of the data to speed up tokenization.
                # Once the tokenized length is greater than or equal to model_max_length, we stop iterating over ratios and use the result as input_ids and labels.
                # After all, the rest of the original string doesn't need to be tokenized in the first place.
                ratio = [16, 8, 4, 2, 1]
                tokenized = None
                for r in ratio:
                    tokenized = self.tokenizer(
                        [input[0 : len(input) // r]],
                        truncation=True,
                        max_length=self.model_max_length,
                        return_tensors="pt",
                    )
                    if tokenized.input_ids.size(1) >= self.model_max_length:
                        break

                string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
                batch_prompt_pretrain.append(string)
                bytes_list.append(len(string.encode("utf-8")))

            batch_prompt = copy.deepcopy(batch_prompt_pretrain)
            batch_target = None
        else:
            batch_prompt_processed = []
            batch_target_processed = []
            for prompt, targets in zip(batch_prompt, batch_target):
                for target in targets:
                    target_tokenized = self.tokenizer(
                        [target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
                    )
                    max_new_tokens = target_tokenized["input_ids"][0].size(0)
                    prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
                    batch_prompt_processed.append(prompt_with_correct_length)
                    batch_target_processed.append(target)

            batch_prompt = copy.deepcopy(batch_prompt_processed)
            batch_target = copy.deepcopy(batch_target_processed)
            bytes_list = None

        # Because of multiple target answers, the flattened batch passed to _calculate_loss may be
        # larger than self.batch_size.
        losses = []
        target_token_nums = []

        losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target)
        losses.extend(losses_per_batch)
        target_token_nums.extend(target_token_num_per_batch)

        start_index = 0
        losses_per_sample = []

        target_token_nums_per_sample = []
        bytes_nums_per_sample = []
        for length in batch_target_nums:
            losses_per_sample.append(losses[start_index : start_index + length])
            target_token_nums_per_sample.append(target_token_nums[start_index : start_index + length])

            if bytes_list:
                bytes_nums_per_sample.append(bytes_list[start_index : start_index + length])

            start_index += length

        if bytes_list:
            return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample

        return losses_per_sample, target_token_nums_per_sample, None


class GetTokenLogitsProcessor:
    """
    LogitsProcessor to get specific logits

    Args:
        indices_for_choices: token indices of required tokens
        target_logits: store all the target logits
    """

    def __init__(
        self,
        indices_for_choices: List[List[int]],
    ):
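        # The choice-index lists are stored wrapped in a one-element tuple and are unwrapped again
        # via self.indices_for_choices[0] in __call__.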
        self.indices_for_choices = (indices_for_choices,)
        self.target_logits = []

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        choice_scores = []

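        # An empty input_ids means no new token has been sampled yet, so this branch only runs at the
        # first generation step, keeping for each choice the largest logit among its candidate token ids.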
        if not input_ids:
            for option_indices in self.indices_for_choices[0]:
                choice_scores.append(logits[option_indices].detach().cpu())

            choice_scores = torch.max(torch.stack(choice_scores), dim=0)[0]
            self.target_logits.append(choice_scores)

        return logits

    def get_target_logits(self) -> torch.Tensor:
        return torch.stack(self.target_logits) if self.target_logits else torch.tensor([])