# mirror of https://github.com/hpcaitech/ColossalAI

from itertools import count
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from .request_handler import RequestHandler

# Axes of the (pipeline-parallel, tensor-parallel) process-group mesh.
PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
    "LlamaForCausalLM",
]


class InferenceEngine:
"""
|
|
InferenceEngine which manages the inference process..
|
|
|
|
Args:
|
|
model (nn.Module): Path or nn.Module of this model.
|
|
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
|
|
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
|
|
verbose (bool): Determine whether or not to log the generation process.
|
|
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: Optional[InferenceConfig] = None,
        verbose: bool = False,
        model_policy: Optional[Policy] = None,
    ) -> None:
        assert inference_config, "Please provide inference_config."
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inference_config = inference_config
        self.model_config = model.config
        self.device = torch.device("cuda")

        # Resolve the runtime dtype and cast the model accordingly; anything
        # other than fp32/fp16 falls back to bf16.
        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
            self.dtype = torch.float32
        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
            self.dtype = torch.float16
            model.half()
        else:
            self.dtype = torch.bfloat16
            model.to(torch.bfloat16)

        if model_policy is None:
            # Fall back to the policy registered for this model type.
            model_policy = model_policy_map[self.model_config.model_type]()

        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)

        self.model = self._shardformer(
            model,
            model_policy,
            None,
            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
        )

        self.verbose = verbose
        if verbose:
            self.logger = get_dist_logger(__name__)

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
        self.counter = count()
    def _verify_config(self) -> None:
        """
        Verify the input config.
        """
        if not isinstance(self.model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
            self.tokenizer, PreTrainedTokenizer
        ):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        assert (
            self.model.__class__.__name__ in _supported_models
        ), f"Model {self.model.__class__.__name__} is not supported."

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        stage_manager: Optional[PipelineStageManager] = None,
        tp_group: Optional[ProcessGroupMesh] = None,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy used to shard the model, determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): The tensor-parallel process group. Defaults to None.

        Returns:
            nn.Module: The sharded model, moved to the GPU.
        """
        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"quant": self.inference_config.quant_mode},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model.cuda()

    def generate(
        self,
        generation_config: GenerationConfig = None,
    ) -> List[str]:
        """
        Run the inference loop until all unfinished sequences are completed.

        Args:
            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            List[str]: Inference results returned by one generation.
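
        Example:
            A short sketch; it assumes the engine was built as in the
            class-level example and that requests were queued beforehand::

                engine.add_request(prompts=["What is tensor parallelism?"])
                outputs = engine.generate(generation_config=GenerationConfig(do_sample=False))
                print(outputs[0])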
        """

        self.generation_config = generation_config

        output_list = []

        # Keep stepping until the request handler reports no unfinished sequences.
        while self.request_handler.check_unfinished_seqs():
            output_list += self.step()

        return output_list

    def add_request(
        self,
        requests_id: List[int] = None,
        prompts: List[str] = None,
        prompts_token_ids: List[List[int]] = None,
    ) -> None:
        """
        Add requests.

        Args:
            requests_id (List[int], optional): The request IDs. Defaults to None.
            prompts (List[str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): Token ids of input prompts. Defaults to None.
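
        Example:
            A brief sketch; the token ids below are purely illustrative, and
            omitted request ids are assigned from the engine's internal counter::

                engine.add_request(prompts=["Hello!", "How are you?"])
                # Or pass pre-tokenized inputs with explicit request ids:
                engine.add_request(requests_id=[0, 1], prompts_token_ids=[[1, 15043], [1, 1128]])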
        """

        block_size = self.inference_config.block_size

        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

        assert (
            len(prompts_token_ids[0]) < self.inference_config.max_input_len
        ), "The length of input prompts must be less than max_input_len."

        prompts_num = len(prompts_token_ids)

        for i in range(prompts_num):
            if requests_id:
                request_id = requests_id[i]
            else:
                request_id = next(self.counter)
            if prompts is None:
                prompt = None
            else:
                prompt = prompts[i]
            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                block_table,
                self.tokenizer.eos_token_id,
                self.inference_config.max_output_len,
            )
            self.request_handler.add_sequence(sequence)
    def step(self) -> List[str]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run the model to generate the next token.
            3. Update the waiting list and running list in RequestHandler and get finished sequences.
            4. Decode and return finished sequences.

        Returns:
            List[str]: Decoded finished sequences generated by one step.
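
        Example:
            ``generate`` normally drives this loop; a manual sketch looks like
            the following (``generation_config`` must be set first, as
            ``generate`` would do)::

                engine.generation_config = GenerationConfig()
                outputs = []
                while engine.request_handler.check_unfinished_seqs():
                    outputs += engine.step()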
        """

        if self.verbose:
            self.logger.info("Running generation step")

        output_list = []
        batch = self.request_handler.schedule()

        logits = self.model(
            batch,
            self.k_cache,
            self.v_cache,
        )

        # Only the logits at the last position are needed for next-token prediction.
        logits = logits[:, -1, :]
        self.request_handler.search_tokens(self.generation_config, logits)
        finished_sequences = self.request_handler.update()

        # Decode completed sequences.
        for seq in finished_sequences:
            output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
            output_list.append(output_str)

        return output_list