# ColossalAI/applications/ColossalEval/colossal_eval/models/vllm.py
# (Original listing metadata: 499 lines, 21 KiB, Python; views: Raw / Normal View / History.)

import copy
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from torch.utils.data import DataLoader
from tqdm import tqdm
from vllm import LLM, SamplingParams
from colossalai.logging import DistributedLogger
from .huggingface import HuggingFaceModel
# Sentinel label id for positions to exclude from a loss; -100 matches the default
# `ignore_index` of torch.nn.CrossEntropyLoss. NOTE(review): not referenced in the visible
# part of this file — presumably kept for parity with the HuggingFace wrapper; confirm.
IGNORE_INDEX: int = -100
class vLLMModel(HuggingFaceModel):
    """
    Model wrapper around vLLM models.

    Args:
        path: The path to a vLLM model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer. When not given, vLLM loads the tokenizer from ``path``.
        tokenizer_kwargs: Keyword arguments for the tokenizer (forwarded to ``vllm.LLM``).
        model_kwargs: Keyword arguments for the model (forwarded to ``vllm.LLM``).
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
        quantization: The method used to quantize the model weights.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
        enforce_eager: Whether to enforce eager execution.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
        disable_custom_all_reduce: See vLLM's ParallelConfig.
        **kwargs: Accepted for configuration compatibility and ignored.
    """

    def __init__(
        self,
        path: str,
        model_max_length: int = 2048,
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: Dict = None,
        model_kwargs: Dict = None,
        prompt_template: Conversation = None,
        batch_size: int = 1,
        logger: DistributedLogger = None,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        gpu_memory_utilization: float = 0.5,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ):
        super().__init__(
            path=path,
            model_max_length=model_max_length,
            prompt_template=prompt_template,
            batch_size=batch_size,
            logger=logger,
        )

        self._load_model(
            path=path,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
            tokenizer_path=tokenizer_path if tokenizer_path else None,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            quantization=quantization,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
        )

    @staticmethod
    def _resolve_torch_dtype(dtype: Any) -> Any:
        """
        Convert a ``torch_dtype`` configuration value into a ``torch.dtype``.

        Accepts a ``torch.dtype`` as-is, or a string such as ``"torch.float16"`` or ``"float16"``.
        The string is resolved by attribute lookup on ``torch`` (never ``eval``), so a config
        value cannot execute arbitrary code.

        Raises:
            ValueError: If the value does not name a ``torch.dtype``.
        """
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            name = dtype.strip()
            if name.startswith("torch."):
                name = name[len("torch.") :]
            resolved = getattr(torch, name, None)
            if isinstance(resolved, torch.dtype):
                return resolved
        raise ValueError(f"Unsupported torch_dtype: {dtype!r}")

    def _load_model(
        self,
        path: str,
        model_kwargs: Optional[dict],
        tokenizer_kwargs: Optional[dict],
        tokenizer_path: Optional[str] = None,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
    ):
        """
        Load the vLLM engine and its tokenizer.

        Args:
            path: The path to the model.
            model_kwargs: Keyword arguments for the model (may be None). A ``torch_dtype`` entry is
                converted to vLLM's ``dtype``; ``dtype`` defaults to ``torch.float16``.
            tokenizer_kwargs: Keyword arguments for the tokenizer (may be None).
            tokenizer_path: The path to the tokenizer; forwarded to ``vllm.LLM`` as ``tokenizer``.
            trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
                Overridden by a ``trust_remote_code`` entry in ``model_kwargs`` or ``tokenizer_kwargs``.
            tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
            quantization: The method used to quantize the model weights.
            gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
            swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
            enforce_eager: Whether to enforce eager execution.
            max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            disable_custom_all_reduce: See vLLM's ParallelConfig.

        Raises:
            ValueError: If ``torch_dtype`` is not a valid dtype, or if batch_size > 1 and no pad token can be derived.
        """
        # Work on copies: the constructor defaults are None, and the pops below must not
        # mutate the caller's configuration dicts.
        model_kwargs = dict(model_kwargs) if model_kwargs else {}
        tokenizer_kwargs = dict(tokenizer_kwargs) if tokenizer_kwargs else {}

        if "torch_dtype" in model_kwargs:
            model_kwargs["dtype"] = self._resolve_torch_dtype(model_kwargs.pop("torch_dtype"))
        else:
            model_kwargs.setdefault("dtype", torch.float16)

        # A trust_remote_code entry in either dict overrides the explicit argument (the tokenizer
        # dict wins if both have one). It is popped so it is not passed to vllm.LLM twice.
        if "trust_remote_code" in model_kwargs:
            trust_remote_code = model_kwargs.pop("trust_remote_code")
        if "trust_remote_code" in tokenizer_kwargs:
            trust_remote_code = tokenizer_kwargs.pop("trust_remote_code")

        # Forward the tokenizer path; an explicit "tokenizer" entry in tokenizer_kwargs takes precedence.
        if tokenizer_path is not None:
            tokenizer_kwargs.setdefault("tokenizer", tokenizer_path)

        self.model = LLM(
            model=path,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            quantization=quantization,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **model_kwargs,
            **tokenizer_kwargs,
        )

        self.tokenizer = self.model.get_tokenizer()

        if self.batch_size > 1:
            # Left padding/truncation keeps the end of each prompt adjacent to the generated tokens.
            self.tokenizer.padding_side = "left"
            self.tokenizer.truncation_side = "left"

        if self.tokenizer.pad_token_id is None:
            self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
            if self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif hasattr(self.tokenizer, "eod_id"):
                # Qwen has an eod token "<|endoftext|>".
                self.tokenizer.pad_token_id = self.tokenizer.eod_id
            else:
                self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
                raise ValueError(
                    "The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
                    "Please set pad_token_id manually."
                )

    def _calculate_loss(self, inputs: List[str], labels: Optional[List[str]]) -> Tuple[List[float], List[int]]:
        """
        Calculate the summed negative log-likelihood of target tokens.

        Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110

        Each prompt is concatenated with its label and scored with vLLM's ``prompt_logprobs``,
        so the label tokens are evaluated conditioned on the prompt (nothing is sampled).

        Args:
            inputs: A batch of prompt strings.
            labels: A batch of target strings aligned with ``inputs``, or None to score every
                token of each input (except the first, which has no preceding context).

        Returns:
            A list of summed losses (one per input) and a list of the number of scored tokens.
        """
        if labels is not None:
            texts = [prompt + label for prompt, label in zip(inputs, labels)]
            label_lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels]
        else:
            texts = list(inputs)
            label_lens = None

        # max_tokens=1: vLLM must produce at least one token, but only the prompt logprobs are used.
        sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)
        outputs = self.model.generate(texts, sampling_params)

        loss_sums: List[float] = []
        token_nums: List[int] = []
        for idx, output in enumerate(outputs):
            # The first prompt token has no context, so vLLM reports no logprob for it.
            step_logprobs = output.prompt_logprobs[1:]
            step_token_ids = output.prompt_token_ids[1:]
            token_logprobs = np.array(
                [step[token_id].logprob for step, token_id in zip(step_logprobs, step_token_ids)],
                dtype=np.float64,
            )

            if label_lens is not None:
                # The label occupies the last tokens of the combined text. NOTE: tokenizing
                # prompt+label jointly can merge tokens at the boundary, so this is approximate.
                num_scored = min(label_lens[idx], len(token_logprobs))
                token_logprobs = token_logprobs[len(token_logprobs) - num_scored :]

            loss_sums.append(float(-token_logprobs.sum()))
            token_nums.append(len(token_logprobs))

        return loss_sums, token_nums

    def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
        """
        Infer the given data.

        Calls ``self.generate()`` for model outputs (its logits processor captures logits over the
        choices) and ``self.get_loss()`` for losses, then writes the results into each sample.

        Args:
            data_loader: The data for inference.
            inference_kwargs: Arguments for inference.
            debug: Whether to display generated prompt for debugging.

        Returns:
            Inference results (the input samples, updated in place).
        """
        calculate_loss = inference_kwargs["calculate_loss"]
        classes = inference_kwargs["all_classes"]
        language = inference_kwargs["language"]
        calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
        max_new_tokens = inference_kwargs["max_new_tokens"]
        few_shot_data = inference_kwargs.get("few_shot_data", None)

        # Some classification questions' options are texts not a single letter such as A, B, C and D.
        # If the text length is greater than 1, we won't calculate loss over choices.
        if classes is not None and any(len(c) > 1 for c in classes):
            classes = None

        self.choices = classes
        self.indices_for_choices = None
        if self.choices:
            # Get indices for each choice
            self._get_choices_indices(language)

            self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

        bar = tqdm(
            range(len(data_loader)),
            desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
            disable=not is_rank_0(),
        )
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        answers = []

        for i, batch in enumerate(data_loader):
            batch_prompt, batch_target = get_batch_prompt(
                self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
            )

            if is_rank_0() and debug and i == 0:
                self.logger.info(
                    f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
                )
                self.logger.info("-" * 120)
                self.logger.info("An example prompt and prompt with target is:")
                self.logger.info("-" * 120)
                self.logger.info(batch_prompt[0])
                self.logger.info("-" * 120)
                self.logger.info(batch_prompt[0] + batch_target[0][0])

            # Nothing is generated in overall-loss mode, so there are no choice scores either.
            batch_decodes, scores = None, []
            if not calculate_overall_loss:
                batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)

            if calculate_loss:
                batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
                    batch_prompt, batch_target, calculate_overall_loss
                )

            probs = []
            loss_over_choices = []
            has_choice_scores = isinstance(scores, torch.Tensor) and scores.numel() > 0
            if self.indices_for_choices and has_choice_scores:
                scores = scores.to(torch.float32)
                # If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample.
                # Otherwise this will violate the single-choice setting.

                if calculate_loss:
                    labels = [self.str_label_map[sample["target"]] for sample in batch]

                    loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()

                # Column k of `scores` holds the logit of choice k (see str_label_map).
                probs = [
                    {choice: row[self.str_label_map[choice]] for choice in self.choices}
                    for row in scores.numpy().tolist()
                ]

            for j in range(len(batch)):
                if not calculate_overall_loss:
                    if isinstance(batch[j]["output"], list):
                        batch[j]["output"].append(batch_decodes[j].strip())
                    else:
                        batch[j]["output"] = batch_decodes[j].strip()

                    if probs:
                        batch[j]["logits_over_choices"] = probs[j]

                        if calculate_loss:
                            batch[j]["loss_over_choices"] = loss_over_choices[j]

                if calculate_loss:
                    # batch_losses holds summed losses, so this is the mean loss per target token.
                    batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()

                    # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity.
                    # However, loss (which is per sample loss) suffices for most cases.
                    batch[j]["loss_sum"] = batch_losses[j]
                    batch[j]["token_num"] = batch_target_token_nums[j]

                    if batch_bytes_nums:
                        batch[j]["byte_num"] = batch_bytes_nums[j]

            answers.extend(batch)

            bar.update()

        return answers

    @torch.no_grad()
    def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> Tuple[List[str], Any]:
        """Generate results given a list of inputs and get logits of the first new token over choices.

        Args:
            inputs: A list of strings.
            max_new_tokens: Max new tokens for generation.
            kwargs: Key arguments for generation (forwarded to ``SamplingParams``).

        Returns:
            A list of generated strings, and the logits over choices: a tensor with one row per
            input when ``self.indices_for_choices`` is set, otherwise an empty list.

        Note:
            Currently the function only returns the logits of the first new token.
            It is used for single choice question.
            For multiple choices question, please avoid using the loss over choices.
            You should set argument choices as None in self.inference().
        """
        truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)

        generation_kwargs = kwargs.copy()
        generation_kwargs.update({"max_tokens": max_new_tokens})

        if self.indices_for_choices:
            # One processor per prompt, so each recorded row belongs to exactly one prompt
            # regardless of the order in which vLLM schedules the requests.
            processors = [GetTokenLogitsProcessor(self.indices_for_choices) for _ in truncated_inputs]
            sampling_params = [
                SamplingParams(logits_processors=[processor], **generation_kwargs) for processor in processors
            ]
        else:
            processors = []
            sampling_params = SamplingParams(**generation_kwargs)

        outputs = self.model.generate(truncated_inputs, sampling_params)

        output_strs = [output.outputs[0].text for output in outputs]

        scores = torch.cat([processor.get_target_logits() for processor in processors], dim=0) if processors else []

        return output_strs, scores

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
    ) -> Tuple[List[List[float]], List[List[int]], Optional[List[List[int]]]]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompts without target answers.
            batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
            calculate_overall_loss: If True (pretrain-style data), score the whole text
                ``prompt + first target`` instead of only the target tokens.

        Returns:
            Per-sample lists of summed losses, per-sample lists of scored token counts, and, only in
            overall-loss mode, per-sample lists of UTF-8 byte counts (otherwise None).
        """
        if calculate_overall_loss:
            # Exactly one text is scored per sample in this mode.
            batch_target_nums = [1] * len(batch_prompt)

            texts = [prompt + targets[0] for prompt, targets in zip(batch_prompt, batch_target)]
            scored_texts = []
            bytes_list = []
            for text in texts:
                # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process.
                # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels.
                # After all, the rest of the original string doesn't need to be tokenized at the first place.
                for ratio in (16, 8, 4, 2, 1):
                    tokenized = self.tokenizer(
                        [text[: len(text) // ratio]],
                        truncation=True,
                        max_length=self.model_max_length,
                        return_tensors="pt",
                    )
                    if tokenized.input_ids.size(1) >= self.model_max_length:
                        break

                decoded = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
                scored_texts.append(decoded)
                bytes_list.append(len(decoded.encode("utf-8")))

            prompts_to_score = scored_texts
            targets_to_score = None
        else:
            # max_new_tokens is 0 in self._get_truncated_prompts because targets are only scored, not generated.
            # Target answer's length is usually << model_max_length, but we still truncate in case.
            batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

            # Get the number of target answers for different questions
            batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

            # Prompts are truncated per target: each target's length must be known first to reserve room for its tokens.
            prompts_to_score = []
            targets_to_score = []
            for prompt, targets in zip(batch_prompt, batch_target):
                for target in targets:
                    target_tokenized = self.tokenizer(
                        [target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
                    )
                    target_len = target_tokenized["input_ids"][0].size(0)
                    prompts_to_score.append(self._get_truncated_prompts([prompt], target_len)[0])
                    targets_to_score.append(target)
            bytes_list = None

        # Because of multiple target answers, the flattened batch may be larger than self.batch_size.
        losses, target_token_nums = self._calculate_loss(prompts_to_score, targets_to_score)

        # Regroup the flattened results per original sample.
        losses_per_sample = []
        target_token_nums_per_sample = []
        bytes_nums_per_sample = []
        start = 0
        for count in batch_target_nums:
            end = start + count
            losses_per_sample.append(losses[start:end])
            target_token_nums_per_sample.append(target_token_nums[start:end])
            if bytes_list:
                bytes_nums_per_sample.append(bytes_list[start:end])
            start = end

        if bytes_list:
            return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample

        return losses_per_sample, target_token_nums_per_sample, None
class GetTokenLogitsProcessor:
    """
    vLLM logits processor that records the logits of specific tokens at the first decoding step.

    The logits are returned unmodified; the requested entries are only copied to the CPU.

    Args:
        indices_for_choices: Groups of token ids. Each inner list is indexed per choice (position k
            is the token id of choice k); at the first decoding step every group's logits are
            gathered and reduced with an element-wise max across groups, giving one score per
            choice. ``None`` or an empty list disables recording.

    Attributes:
        target_logits: One tensor of per-choice scores per observed first decoding step.
    """

    def __init__(
        self,
        indices_for_choices: Optional[List[List[int]]],
    ):
        self.indices_for_choices = indices_for_choices
        self.target_logits: List[torch.Tensor] = []

    def __call__(self, input_ids, logits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: Token ids generated so far for the sequence; empty at the first decoding step.
            logits: Next-token logits over the vocabulary.

        Returns:
            ``logits`` unchanged.
        """
        # len() works for lists, tuples and 1-D tensors (truthiness of a tensor is ambiguous).
        if self.indices_for_choices and len(input_ids) == 0:
            group_scores = [logits[group].detach().cpu() for group in self.indices_for_choices]
            self.target_logits.append(torch.max(torch.stack(group_scores), dim=0)[0])
        return logits

    def get_target_logits(self) -> torch.Tensor:
        """Return the recorded scores stacked along a new first dim, or an empty tensor if none were recorded."""
        return torch.stack(self.target_logits) if self.target_logits else torch.tensor([])