mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * Add the logic of the inference engine * update engine and test * Recover cache_manager.py * add logger * fix conflict * update codes * update codes * update model and tokenizer * fix add the logic about shardformer * change kvcache_manager docstring * add policy * fix ci bug in test_kvcache_manager.py * remove codes related o tokenizer and move model_policy * fix code style * add ordered_set to requirements-infer.txt * Delete extra empty lines * add ordered_set to requirements-test.txtpull/5258/head
parent
93aeacca34
commit
8daee26989
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
|
||||
GibiByte = 1024**3
|
||||
|
||||
|
@ -15,44 +15,44 @@ class InferenceConfig:
|
|||
"""The inference configuration.
|
||||
|
||||
Args:
|
||||
model: Path or nn.Module of this model.
|
||||
tokenizer: Path of the tokenizer to use.
|
||||
tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
|
||||
trust_remote_code: Whether to trust remote code from huggingface.
|
||||
max_batch_size: Maximum batch size.
|
||||
max_output_len: Maximum output length.
|
||||
max_input_len: Maximum input length.
|
||||
block_size: The number of blocks in a logical block.
|
||||
dtype: The data type for weights and activations.
|
||||
tp_size: Tensor parallel size.
|
||||
pp_size: Pipeline parallel size.
|
||||
max_seq_len: Maximum length of input sentence.
|
||||
quant_mode: Quantization mode.
|
||||
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
|
||||
beam_width: The maximum beam width used to initialize KV Cache.
|
||||
micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
max_batch_size (int): Maximum batch size.
|
||||
max_output_len (int): Maximum output length.
|
||||
max_input_len (int): Maximum input length.
|
||||
block_size (int): The number of blocks in a logical block.
|
||||
dtype (Union[str, torch.dtype]): The data type for weights and activations.
|
||||
tp_size (int): Tensor parallel size.
|
||||
pp_size (int): Pipeline parallel size.
|
||||
max_seq_len (int): Maximum length of input sentence.
|
||||
beam_width (int): The maximum beam width used to initialize KV Cache.
|
||||
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
|
||||
prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
|
||||
prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
|
||||
when the actual value exceeds this ratio.
|
||||
quant_mode (Optional[str]): Quantization mode.
|
||||
revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
|
||||
"""
|
||||
|
||||
model: Union[str, nn.Module]
|
||||
tokenizer: str = None
|
||||
tokenizer_mode: str = "auto"
|
||||
trust_remote_code: bool = False
|
||||
max_batch_size: int = None
|
||||
micro_batch_size: int = 1
|
||||
micro_batch_buffer_size: int = None
|
||||
max_batch_size: int = 8
|
||||
max_output_len: int = 256
|
||||
max_input_len: int = 256
|
||||
block_size: int = 16
|
||||
dtype: Union[str, torch.dtype] = torch.float32
|
||||
tp_size: int = 1
|
||||
pp_size: int = 1
|
||||
max_seq_len: Optional[int] = None
|
||||
max_seq_len: int = 512
|
||||
# TODO: beam search is not support for now
|
||||
beam_width: int = 1
|
||||
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
prefill_ratio: Optional[float] = 1.2
|
||||
quant_mode: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
beam_width: int = 1
|
||||
# TODO: beam search is not support for now
|
||||
prefill_ratio: Optional[float] = 1.2
|
||||
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
|
||||
def __post_init__(self):
|
||||
self._init_batch_size()
|
||||
self._verify_config()
|
||||
|
||||
def _init_batch_size(self):
|
||||
"""
|
||||
|
@ -75,10 +75,20 @@ class InferenceConfig:
|
|||
f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self._init_batch_size()
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self):
|
||||
if self.tokenizer_mode not in ["auto", "slow"]:
|
||||
raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
|
||||
def _verify_config(self) -> None:
|
||||
"""
|
||||
Verify the input config
|
||||
"""
|
||||
assert (
|
||||
self.tp_size * self.pp_size == dist.get_world_size()
|
||||
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
|
||||
assert self.dtype in [
|
||||
"fp16",
|
||||
"fp32",
|
||||
"bf16",
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
|
||||
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
|
||||
assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
|
||||
|
|
|
@ -1,65 +1,232 @@
|
|||
from logging import Logger
|
||||
from typing import Optional
|
||||
from itertools import count
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers import AutoConfig
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .request_handler import RequestHandler
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = [
|
||||
"LlamaForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""
|
||||
InferenceEngine is the core component for Inference.
|
||||
|
||||
It is responsible for launch the inference process, including:
|
||||
- Initialize model and distributed training environment(if needed)
|
||||
- Launch request_handler and corresponding kv cache manager
|
||||
- Receive requests and generate texts.
|
||||
- Log the generation process
|
||||
"""
|
||||
InferenceEngine which manages the inference process..
|
||||
|
||||
Args:
|
||||
tokenizer: Path of the tokenizer to use.
|
||||
inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
|
||||
model (nn.Module): Path or nn.Module of this model.
|
||||
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
|
||||
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
|
||||
verbose (bool): Determine whether or not to log the generation process.
|
||||
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: str = None,
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
inference_config: Optional["InferenceConfig"] = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Policy = None,
|
||||
) -> None:
|
||||
assert inference_config, "Please provide inference_config."
|
||||
|
||||
self._init_model()
|
||||
# cache_config may need to be modified later.
|
||||
# self.request_handler = RequestHandler(cache_config)
|
||||
self.tokenizer = tokenizer
|
||||
self.hf_model_config = AutoConfig.from_pretrained(
|
||||
self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
|
||||
self.inference_config = inference_config
|
||||
self.model_config = model.config
|
||||
|
||||
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
|
||||
self.dtype = torch.float32
|
||||
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
|
||||
self.dtype = torch.float16
|
||||
model.half()
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
if model_policy is None:
|
||||
model_policy = model_policy_map[self.model_config.model_type]()
|
||||
|
||||
pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
None,
|
||||
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
|
||||
)
|
||||
|
||||
self.verbose = verbose
|
||||
if verbose:
|
||||
self.logger = Logger()
|
||||
self.logger = get_dist_logger(__name__)
|
||||
|
||||
def _init_model(self):
|
||||
self.request_handler = RequestHandler(self.inference_config, self.model_config)
|
||||
self.counter = count()
|
||||
|
||||
def _verify_config(self) -> None:
|
||||
"""
|
||||
Initialize model and distributed training environment(if needed).
|
||||
May need to provide two different initialization methods:
|
||||
1. 用户自定义(from local path)
|
||||
2. 从checkpoint加载(hugging face)
|
||||
Verify the input config
|
||||
"""
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
|
||||
self.tokenizer, PreTrainedTokenizer
|
||||
):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
|
||||
)
|
||||
assert (
|
||||
self.model.__class__.__name__ in _supported_models
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def _shardformer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Initialize ShardConfig and replace the model with shardformer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Path or nn.Module of this model.
|
||||
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
||||
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
||||
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: _description_
|
||||
"""
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"quant": self.inference_config.quant_mode},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
||||
|
||||
def generate(
|
||||
self,
|
||||
generation_config: GenerationConfig = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[str]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
def _verify_config(self):
|
||||
self.generation_config = generation_config
|
||||
|
||||
output_list = []
|
||||
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_list += self.step()
|
||||
|
||||
return output_list
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
requests_id: List[int] = None,
|
||||
prompts: List[str] = None,
|
||||
prompts_token_ids: List[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Verify the configuration to avoid potential bugs.
|
||||
Add requests.
|
||||
|
||||
Args:
|
||||
requests_id (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
|
||||
def generate(self):
|
||||
pass
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
def step(self):
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = []
|
||||
for prompt in prompts:
|
||||
prompts_token_ids.append(self.tokenizer.encode(prompt))
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if requests_id:
|
||||
request_id = requests_id[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts == None:
|
||||
prompt = None
|
||||
else:
|
||||
prompt = prompts[i]
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.inference_config.max_output_len,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the follows:
|
||||
1. Run request_handler to update the kv cache and running input_ids
|
||||
1. Run RequestHandler.schedule() and get the batch used for inference.
|
||||
2. Run model to generate the next token
|
||||
3. Check whether there is finied request and decode
|
||||
3. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
4. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Running generation step")
|
||||
|
||||
output_list = []
|
||||
self.request_handler.schedule()
|
||||
|
||||
# Uncomment if the development of RequestHandler is completed.
|
||||
# logits = self.model(batch)
|
||||
# self.request_handler.search_tokens(logits, self.generation_config)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
# Decode completed sentences.
|
||||
for seq in finished_sequences:
|
||||
if seq.prompt:
|
||||
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(seq.prompt + output_str)
|
||||
else:
|
||||
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(output_str)
|
||||
|
||||
return output_list
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from typing import List
|
||||
|
||||
from colossalai.inference.struct import BatchInfo, Sequence
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
"""
|
||||
|
@ -7,14 +9,17 @@ class RequestHandler:
|
|||
During generation process, we call schedule function each iteration to update current batch.
|
||||
|
||||
Args:
|
||||
cache_config: Configuration for initialize and manage kv cache.
|
||||
inference_config: Store the configuration information related to inference.
|
||||
model_config: The huggingface model config.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_config) -> None:
|
||||
self.cache_config = cache_config
|
||||
def __init__(self, inference_config, model_config) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.model_config = model_config
|
||||
self._init_cache()
|
||||
self.waiting_list: List["Reqseq"] = []
|
||||
self.running_list: List["Reqseq"] = []
|
||||
self.waiting_list: List["Sequence"] = []
|
||||
self.running_list: List["Sequence"] = []
|
||||
self.batch = BatchInfo.init_batch()
|
||||
|
||||
def _init_cache(self):
|
||||
"""
|
||||
|
@ -25,12 +30,17 @@ class RequestHandler:
|
|||
"""
|
||||
The main logic of request handler.
|
||||
"""
|
||||
# The code below is only used for testing engine and will be modified.
|
||||
if self.waiting_list:
|
||||
self.running_list = self.waiting_list
|
||||
self.batch.add_seqs(self.running_list)
|
||||
return self.batch
|
||||
|
||||
def add_sequence(self, reqseq: "Reqseq"):
|
||||
def add_sequence(self, req_seq: "Sequence"):
|
||||
"""
|
||||
Add the request to waiting list.
|
||||
"""
|
||||
self.waiting_list.append(reqseq)
|
||||
self.waiting_list.append(req_seq)
|
||||
|
||||
def abort_sequence(self, seq_id: str):
|
||||
"""
|
||||
|
@ -39,10 +49,23 @@ class RequestHandler:
|
|||
self._find_sequence(seq_id)
|
||||
return
|
||||
|
||||
def _find_sequence(self, seq_id: str) -> "Reqseq":
|
||||
def _find_sequence(self, seq_id: str) -> "Sequence":
|
||||
"""
|
||||
Find the request by seq_id.
|
||||
"""
|
||||
|
||||
def check_unfinished_seqs(self) -> bool:
|
||||
return self.waiting_list or self.running_list
|
||||
return len(self.waiting_list) != 0 or len(self.running_list) != 0
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Update the waiting list and running list.
|
||||
"""
|
||||
|
||||
# The code below is only used for testing engine and will be modified.
|
||||
self.waiting_list = []
|
||||
self.running_list = []
|
||||
finished_sequences = list(self.batch.sequences_set)
|
||||
|
||||
self.batch.clear_batch()
|
||||
return finished_sequences
|
||||
|
|
|
@ -135,7 +135,7 @@ class KVCacheManager:
|
|||
and updates the provided block table with the allocated block ids.
|
||||
|
||||
Args:
|
||||
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
|
||||
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
|
||||
context_len: The length of the processing sequnece.
|
||||
"""
|
||||
assert block_table.dim() == 1
|
||||
|
@ -185,7 +185,7 @@ class KVCacheManager:
|
|||
and updates the provided block table if a new cache block is needed.
|
||||
|
||||
Args:
|
||||
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
|
||||
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
|
||||
context_len: The length of the processing sequnece (already-allocated length).
|
||||
"""
|
||||
assert block_table.dim() == 1
|
||||
|
@ -199,7 +199,7 @@ class KVCacheManager:
|
|||
and updates the provided block table with the allocated block.
|
||||
|
||||
Args:
|
||||
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
|
||||
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
|
||||
block_local_idx: The index of the block in the block table.
|
||||
space_asked: i.e. The number of tokens to be assigned space for.
|
||||
Returns:
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from .llama import LlamaModelInferPolicy
|
||||
|
||||
model_policy_map = {
|
||||
"llama": LlamaModelInferPolicy,
|
||||
}
|
||||
|
||||
__all__ = ["LlamaModelInferPolicy", "model_polic_map"]
|
|
@ -0,0 +1,7 @@
|
|||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
|
||||
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
# The code here just for test and will be modified later.
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
|
@ -1,68 +1,82 @@
|
|||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Set
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from ordered_set import OrderedSet
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
"""
|
||||
The abstraction of request and sequence are defined here.
|
||||
"""
|
||||
|
||||
|
||||
class RequsetStatus(enum.Enum):
|
||||
"""The status of Sentences"""
|
||||
class RequestStatus(enum.Enum):
|
||||
"""
|
||||
The status of Sentences
|
||||
"""
|
||||
|
||||
# running status
|
||||
WAITING = enum.auto()
|
||||
RUNNING = enum.auto()
|
||||
PREFILL = enum.auto()
|
||||
TOKEN = enum.auto()
|
||||
ABORTED = enum.auto()
|
||||
|
||||
# completion status
|
||||
OVERLENGTH = enum.auto()
|
||||
COMPLETED = enum.auto()
|
||||
LENGTH_CAPPED = enum.auto()
|
||||
|
||||
@staticmethod
|
||||
def is_finished(status: "RequsetStatus") -> bool:
|
||||
def is_finished(status: "RequestStatus") -> bool:
|
||||
return status in [
|
||||
RequsetStatus.OVERLENGTH,
|
||||
RequsetStatus.COMPLETED,
|
||||
RequsetStatus.LENGTH_CAPPED,
|
||||
RequestStatus.OVERLENGTH,
|
||||
RequestStatus.COMPLETED,
|
||||
RequestStatus.LENGTH_CAPPED,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def is_running(status: "RequsetStatus") -> bool:
|
||||
return status == RequsetStatus.RUNNING
|
||||
def is_running(status: "RequestStatus") -> bool:
|
||||
return status in [
|
||||
RequestStatus.PREFILL,
|
||||
RequestStatus.TOKEN,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def is_waiting(status: "RequsetStatus") -> bool:
|
||||
return status == RequsetStatus.WAITING
|
||||
def is_waiting(status: "RequestStatus") -> bool:
|
||||
return status == RequestStatus.WAITING
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sequence:
|
||||
"""Store information of input sequence.
|
||||
|
||||
Args:
|
||||
request_id: The ID of input sequence.
|
||||
prompt: The prompt of input sequence.
|
||||
token_id: The tokens ID of input sequence.
|
||||
block_size: The block size of input sequence.
|
||||
sample_params: The sample_params of input sequence.
|
||||
block_table_index: The index of input sequence in block_table.
|
||||
request_id (int): The ID of input sequence.
|
||||
prompt (str): The prompt of input sequence.
|
||||
input_token_id (List[int]): The tokens ID of input sequence.
|
||||
block_size (int): The block size of input sequence.
|
||||
sample_params (SampleParams): The sample_params of input sequence.
|
||||
block_table (torch.Tensor): The index of input sequence in block_table.
|
||||
eos_token_id (int): The eos token id for this inference process.
|
||||
max_output_len (int): Maximum output length.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: int,
|
||||
prompt: str,
|
||||
token_id: List[int],
|
||||
block_size: int,
|
||||
sample_params, # SampleParams needs to be imported later.
|
||||
block_table_index: int,
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.input_token_id = token_id
|
||||
self.blokc_size = block_size
|
||||
self.sample_params = sample_params
|
||||
request_id: int
|
||||
prompt: str
|
||||
input_token_id: List[int]
|
||||
block_size: int
|
||||
sample_params: any # SampleParams needs to be imported later.
|
||||
block_table: torch.Tensor
|
||||
eos_token_id: int
|
||||
max_output_len: int = 256
|
||||
|
||||
def __post_init__(self):
|
||||
self.output_token_id = []
|
||||
self.status = RequsetStatus.WAITING
|
||||
self.block_table_index = block_table_index
|
||||
self.status = RequestStatus.WAITING
|
||||
|
||||
def get_sentence_len(self) -> None:
|
||||
"""
|
||||
|
@ -84,17 +98,30 @@ class Sequence:
|
|||
|
||||
def check_finish(self) -> bool:
|
||||
"""
|
||||
Check whether inference is over.
|
||||
Check whether the inference is finished.
|
||||
|
||||
Returns:
|
||||
bool: Whether the inference is finished.
|
||||
"""
|
||||
return RequsetStatus.is_finished(self.status)
|
||||
if RequestStatus.is_finished(self.status):
|
||||
return True
|
||||
|
||||
if self.output_token_id:
|
||||
if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
|
||||
self.status = RequestStatus.COMPLETED
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.request_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Request ID(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt}, "
|
||||
f"status={self.status.name}, "
|
||||
f"sample_params={self.sample_params}, "
|
||||
f"logical block number={len(self._logical_blocks)}"
|
||||
f"sample_params={self.sample_params}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -104,34 +131,38 @@ class BatchInfo:
|
|||
Information to be passed and used for a batch of sequences.
|
||||
"""
|
||||
|
||||
sequences_set: Set[Sequence]
|
||||
block_table: Dict[int, int] = None
|
||||
sequences_set: OrderedSet["Sequence"]
|
||||
|
||||
@classmethod
|
||||
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
|
||||
def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
|
||||
"""
|
||||
Initializes inference batches by input sentence list.
|
||||
|
||||
Args:
|
||||
seqs (List[Sequence]): List of input sequence.
|
||||
seqs (List["Sequence"]): List of input sequence.
|
||||
"""
|
||||
sequences_set = set()
|
||||
block_table = {}
|
||||
for seq in seqs:
|
||||
if seq in sequences_set:
|
||||
assert (
|
||||
seq.request_id in block_table.keys()
|
||||
), "The sequence has been added to sequences_set, but it has not been added to block_table."
|
||||
continue
|
||||
|
||||
assert (
|
||||
seq.request_id not in block_table.keys()
|
||||
), "The sequence has not been added to sequences_set, but it is already in block_table."
|
||||
sequences_set = OrderedSet()
|
||||
|
||||
sequences_set.add(seq)
|
||||
block_table[seq.request_id] = seq.block_table_index
|
||||
if seqs is not None:
|
||||
if not isinstance(seqs, list):
|
||||
seqs = [seqs]
|
||||
for seq in seqs:
|
||||
if seq in sequences_set:
|
||||
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
|
||||
continue
|
||||
|
||||
return cls(sequences_set=sequences_set, block_table=block_table)
|
||||
sequences_set.add(seq)
|
||||
|
||||
return cls(sequences_set=sequences_set)
|
||||
|
||||
def get_block_table_tensor(self):
|
||||
tesnor_list = []
|
||||
for seq in self.sequences_set:
|
||||
block_table = seq.block_table
|
||||
assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
|
||||
tesnor_list.append(seq.block_table)
|
||||
return torch.concat(tesnor_list)
|
||||
|
||||
def clear_batch(self) -> None:
|
||||
"""
|
||||
|
@ -139,35 +170,76 @@ class BatchInfo:
|
|||
"""
|
||||
for seq in self.sequences_set:
|
||||
if not seq.check_finish():
|
||||
seq.status = RequsetStatus.ABORTED
|
||||
seq.status = RequestStatus.ABORTED
|
||||
self.sequences_set.clear()
|
||||
self.block_table.clear()
|
||||
|
||||
def fliter_batch(self) -> None:
|
||||
def fliter_batch(self) -> List["Sequence"]:
|
||||
"""
|
||||
Remove completed sentences from a batch.
|
||||
"""
|
||||
for seq in self.sequences_set.copy():
|
||||
if seq.check_finish():
|
||||
self.sequences_set.remove(seq)
|
||||
del self.block_table[seq.request_id]
|
||||
|
||||
def add_seqs(self, seqs: List[Sequence]) -> None:
|
||||
Returns:
|
||||
List["Sequence"]: List of finished sequences.
|
||||
"""
|
||||
finish_seqs = []
|
||||
for seq in self.sequences_set:
|
||||
if seq.check_finish():
|
||||
finish_seqs.append(seq)
|
||||
for finish_seq in finish_seqs:
|
||||
self.sequences_set.discard(finish_seq)
|
||||
return finish_seqs
|
||||
|
||||
def abort_seq(self, seq: "Sequence") -> "Sequence":
|
||||
"""
|
||||
Remove sequence from the batch.
|
||||
"""
|
||||
if not seq.check_finish():
|
||||
seq.status = RequestStatus.ABORTED
|
||||
self.sequences_set.discard(seq)
|
||||
return seq
|
||||
|
||||
def add_seqs(self, seqs: List["Sequence"]) -> None:
|
||||
"""
|
||||
Add new sequence to batch
|
||||
|
||||
Args:
|
||||
seqs (List[Sequence]): The list of new sequences.
|
||||
seqs (List["Sequence"]): The list of new sequences.
|
||||
"""
|
||||
|
||||
if not isinstance(seqs, list):
|
||||
seqs = [seqs]
|
||||
|
||||
for seq in seqs:
|
||||
if seq in self.sequences_set:
|
||||
print("The sequence is already in sequences_set.")
|
||||
assert (
|
||||
seq.request_id in self.block_table
|
||||
), "The sequence has been added to sequences_set, but it has not been added to block_table."
|
||||
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
|
||||
continue
|
||||
assert (
|
||||
seq.request_id not in self.block_table
|
||||
), "The sequence has not been added to sequences_set, but it is already in block_table."
|
||||
self.sequences_set.add(seq)
|
||||
self.block_table[seq.request_id] = seq.block_table_index
|
||||
|
||||
def is_empty(self) -> None:
|
||||
"""
|
||||
Check whether sequences_set is empty.
|
||||
"""
|
||||
return not self.sequences_set
|
||||
|
||||
def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
|
||||
"""
|
||||
Add an output token for each sentence in the batch.
|
||||
|
||||
Args:
|
||||
tokens (List[int]): A batch of tokens
|
||||
"""
|
||||
|
||||
assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
|
||||
|
||||
for seq, token in zip(self.sequences_set, tokens):
|
||||
if not isinstance(token, list):
|
||||
if not isinstance(token, int):
|
||||
raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
|
||||
token = [token]
|
||||
seq.output_token_id += token
|
||||
seq.check_finish()
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""
|
||||
Get batch_size of this batch
|
||||
"""
|
||||
return len(self.sequences_set)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
ordered_set
|
||||
transformers==4.34.0
|
||||
auto-gptq==0.5.0
|
||||
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
|
||||
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
|
||||
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
|
|
@ -1,4 +1,6 @@
|
|||
diffusers
|
||||
fbgemm-gpu==0.2.0
|
||||
ordered_set
|
||||
pytest
|
||||
coverage==7.2.3
|
||||
git+https://github.com/hpcaitech/pytest-testmon
|
||||
|
|
|
@ -1,26 +1,45 @@
|
|||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
|
||||
from colossalai.inference.struct import BatchInfo, Sequence
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def test_config_and_inferenceData():
|
||||
config = InferenceConfig("/llama")
|
||||
assert config.max_batch_size
|
||||
def check_config_and_inference():
|
||||
config = InferenceConfig()
|
||||
assert config.max_batch_size == 8
|
||||
sequence = Sequence(
|
||||
request_id=1,
|
||||
prompt="abc",
|
||||
token_id=[1, 2, 3],
|
||||
input_token_id=[1, 2, 3],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
block_table_index=1,
|
||||
block_table=None,
|
||||
eos_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
sequence2 = Sequence(
|
||||
request_id=2,
|
||||
prompt="bcd",
|
||||
token_id=[4, 5, 6],
|
||||
input_token_id=[4, 5, 6],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
block_table_index=2,
|
||||
block_table=None,
|
||||
eos_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
sequence3 = Sequence(
|
||||
request_id=3,
|
||||
prompt="efg",
|
||||
input_token_id=[7, 8, 9],
|
||||
block_size=16,
|
||||
sample_params=None,
|
||||
block_table=None,
|
||||
eos_token_id=2,
|
||||
max_output_len=256,
|
||||
)
|
||||
|
||||
assert sequence.get_sentence_len() == 3
|
||||
|
@ -29,15 +48,34 @@ def test_config_and_inferenceData():
|
|||
assert sequence.check_finish() == False
|
||||
|
||||
batch = BatchInfo.init_batch([sequence])
|
||||
assert batch.block_table[sequence.request_id] == sequence.block_table_index
|
||||
sequence.status = RequsetStatus.COMPLETED
|
||||
batch.fliter_batch()
|
||||
assert batch.block_table == {}
|
||||
batch.add_seqs([sequence2])
|
||||
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
|
||||
batch.add_seqs([sequence2, sequence3])
|
||||
batch.add_seqs([sequence])
|
||||
|
||||
assert batch.is_empty() == False
|
||||
assert batch.get_batch_size() == 3
|
||||
batch.update_batch_tokens([1, 2, 3])
|
||||
seq = batch.abort_seq(sequence)
|
||||
seq2 = batch.fliter_batch()[0]
|
||||
|
||||
assert batch.get_batch_size() == 1
|
||||
assert seq.get_output_len() == 1
|
||||
assert seq.output_token_id == [1]
|
||||
assert seq2.get_output_len() == 1
|
||||
assert seq2.output_token_id == [2]
|
||||
|
||||
batch.clear_batch()
|
||||
assert batch.block_table == {}
|
||||
assert batch.is_empty() == True
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_config_and_inference()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_config_and_inference():
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_config_and_inferenceData()
|
||||
test_config_and_inference()
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
import pytest
|
||||
import transformers
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def check_inference_engine():
|
||||
model = transformers.LlamaForCausalLM(
|
||||
transformers.LlamaConfig(
|
||||
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
|
||||
)
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
inference_config = InferenceConfig()
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
|
||||
inputs = [
|
||||
"介绍一下北京",
|
||||
"介绍一下武汉",
|
||||
]
|
||||
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
outputs = inference_engine.generate(None)
|
||||
|
||||
for s1, s2 in zip(inputs, outputs):
|
||||
assert s1 == s2
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_inference_engine()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_inference_engine():
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inference_engine()
|
|
@ -1,12 +1,14 @@
|
|||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.testing import parameterize, spawn
|
||||
|
||||
|
||||
@parameterize(
|
||||
|
@ -64,7 +66,7 @@ def test_logical_blocks(test_config):
|
|||
},
|
||||
],
|
||||
)
|
||||
def test_cache_manager(test_config):
|
||||
def check_cache_manager(test_config):
|
||||
disable_existing_loggers()
|
||||
|
||||
assert test_config["max_batch_size"] > 1
|
||||
|
@ -78,7 +80,7 @@ def test_cache_manager(test_config):
|
|||
max_input_length = test_config["max_input_len"]
|
||||
max_output_length = test_config["max_output_len"]
|
||||
|
||||
inference_config = InferenceConfig(model="", **test_config)
|
||||
inference_config = InferenceConfig(**test_config)
|
||||
model_config = LlamaConfig(
|
||||
hidden_size=hidden_size,
|
||||
num_hidden_layers=num_layers,
|
||||
|
@ -147,6 +149,16 @@ def test_cache_manager(test_config):
|
|||
assert cache_manager.get_num_available_blocks() == num_blocks
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_cache_manager()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_cache_manager():
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_logical_blocks()
|
||||
test_cache_manager()
|
||||
|
|
Loading…
Reference in New Issue