[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config

* update codes

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update codes

* made adjustments according to the feedback from the reviewer.

* update codes

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update codes

* update codes

* update model and tokenizer

* fix add the logic about shardformer

* change kvcache_manager docstring

* add policy

* fix ci bug in test_kvcache_manager.py

* remove codes related to tokenizer and move model_policy

* fix code style

* add ordered_set to requirements-infer.txt

* Delete extra empty lines

* add ordered_set to requirements-test.txt
pull/5258/head
yuehuayingxueluo 2023-12-18 10:40:47 +08:00 committed by FrankLeeeee
parent 93aeacca34
commit 8daee26989
13 changed files with 555 additions and 172 deletions

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from typing import Optional, Union

 import torch
-import torch.nn as nn
+import torch.distributed as dist

 GibiByte = 1024**3
@@ -15,44 +15,44 @@ class InferenceConfig:
     """The inference configuration.

     Args:
-        model: Path or nn.Module of this model.
-        tokenizer: Path of the tokenizer to use.
-        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
-        trust_remote_code: Whether to trust remote code from huggingface.
-        max_batch_size: Maximum batch size.
-        max_output_len: Maximum output length.
-        max_input_len: Maximum input length.
-        block_size: The number of blocks in a logical block.
-        dtype: The data type for weights and activations.
-        tp_size: Tensor parallel size.
-        pp_size: Pipeline parallel size.
-        max_seq_len: Maximum length of input sentence.
-        quant_mode: Quantization mode.
-        revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
-        beam_width: The maximum beam width used to initialize KV Cache.
+        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+        max_batch_size (int): Maximum batch size.
+        max_output_len (int): Maximum output length.
+        max_input_len (int): Maximum input length.
+        block_size (int): The number of blocks in a logical block.
+        dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        tp_size (int): Tensor parallel size.
+        pp_size (int): Pipeline parallel size.
+        max_seq_len (int): Maximum length of input sentence.
+        beam_width (int): The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
             when the actual value exceeds this ratio.
+        quant_mode (Optional[str]): Quantization mode.
+        revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
     """

-    model: Union[str, nn.Module]
-    tokenizer: str = None
-    tokenizer_mode: str = "auto"
-    trust_remote_code: bool = False
-    max_batch_size: int = None
+    micro_batch_size: int = 1
+    micro_batch_buffer_size: int = None
+    max_batch_size: int = 8
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
-    max_seq_len: Optional[int] = None
+    max_seq_len: int = 512
+    # TODO: beam search is not support for now
+    beam_width: int = 1
+    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    prefill_ratio: Optional[float] = 1.2
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
-    beam_width: int = 1
-    # TODO: beam search is not support for now
-    prefill_ratio: Optional[float] = 1.2
-    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+
+    def __post_init__(self):
+        self._init_batch_size()
+        self._verify_config()

     def _init_batch_size(self):
         """
@@ -75,10 +75,20 @@ class InferenceConfig:
                 f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
             )

-    def __post_init__(self):
-        self._init_batch_size()
-        self._verify_args()
-
-    def _verify_args(self):
-        if self.tokenizer_mode not in ["auto", "slow"]:
-            raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        assert (
+            self.tp_size * self.pp_size == dist.get_world_size()
+        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+        assert self.dtype in [
+            "fp16",
+            "fp32",
+            "bf16",
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
+        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
+        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
+        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
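
For reference, a minimal construction sketch of the reworked config (not part of the committed diff; values are illustrative). Because `_verify_config` now calls `dist.get_world_size()`, the dataclass is expected to be instantiated after the distributed environment has been set up, e.g. inside a `colossalai.launch`-initialized process, as the updated tests below do.

import torch
from colossalai.inference.config import InferenceConfig

# Assumes torch.distributed is already initialized with world_size == tp_size * pp_size.
inference_config = InferenceConfig(
    max_batch_size=8,       # must stay <= 64 (checked in _verify_config)
    max_input_len=256,
    max_output_len=256,
    dtype=torch.float16,    # one of "fp16"/"fp32"/"bf16" or the matching torch dtypes
    tp_size=1,
    pp_size=1,              # tp_size * pp_size must equal dist.get_world_size()
)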

View File

@@ -1,65 +1,232 @@
-from logging import Logger
-from typing import Optional
-
-from transformers import AutoConfig
-
-from colossalai.inference.config import InferenceConfig
-
-
-class InferenceEngine:
-    """
-    InferenceEngine is the core component for Inference.
-    It is responsible for launch the inference process, including:
-        - Initialize model and distributed training environment(if needed)
-        - Launch request_handler and corresponding kv cache manager
-        - Receive requests and generate texts.
-        - Log the generation process
-
-    Args:
-        tokenizer: Path of the tokenizer to use.
-        inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
-        verbose (bool): Determine whether or not to log the generation process.
-    """
-
-    def __init__(
-        self,
-        tokenizer: str = None,
-        inference_config: Optional["InferenceConfig"] = None,
-        verbose: bool = False,
-    ) -> None:
-        assert inference_config, "Please provide inference_config."
-        self._init_model()
-        # cache_config may need to be modified later.
-        # self.request_handler = RequestHandler(cache_config)
-        self.tokenizer = tokenizer
-        self.hf_model_config = AutoConfig.from_pretrained(
-            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
-        )
-        if verbose:
-            self.logger = Logger()
-
-    def _init_model(self):
-        """
-        Initialize model and distributed training environment(if needed).
-        May need to provide two different initialization methods:
-            1. User-defined (from local path)
-            2. Load from checkpoint (hugging face)
-        """
-
-    def _verify_config(self):
-        """
-        Verify the configuration to avoid potential bugs.
-        """
-
-    def generate(self):
-        pass
-
-    def step(self):
-        """
-        In each step, do the follows:
-            1. Run request_handler to update the kv cache and running input_ids
-            2. Run model to generate the next token
-            3. Check whether there is finied request and decode
-        """
+from itertools import count
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import Sequence
+from colossalai.logging import get_dist_logger
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .request_handler import RequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+    "LlamaForCausalLM",
+]
+
+
+class InferenceEngine:
+
+    """
+    InferenceEngine which manages the inference process..
+
+    Args:
+        model (nn.Module): Path or nn.Module of this model.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
+        verbose (bool): Determine whether or not to log the generation process.
+        model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        inference_config: Optional["InferenceConfig"] = None,
+        verbose: bool = False,
+        model_policy: Policy = None,
+    ) -> None:
+        assert inference_config, "Please provide inference_config."
+        self.tokenizer = tokenizer
+        self.inference_config = inference_config
+        self.model_config = model.config
+
+        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
+            self.dtype = torch.float16
+            model.half()
+        else:
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+
+        if model_policy is None:
+            model_policy = model_policy_map[self.model_config.model_type]()
+
+        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
+        )
+
+        self.verbose = verbose
+        if verbose:
+            self.logger = get_dist_logger(__name__)
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.counter = count()
+
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
+            self.tokenizer, PreTrainedTokenizer
+        ):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+            )
+        assert (
+            self.model.__class__.__name__ in _supported_models
+        ), f"Model {self.model.__class__.__name__} is not supported."
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): Path or nn.Module of this model.
+            model_policy (Policy): The policy to shardformer model which is determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
+
+        Returns:
+            nn.Module: _description_
+        """
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"quant": self.inference_config.quant_mode},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()
+
+    def generate(
+        self,
+        generation_config: GenerationConfig = None,
+    ) -> List[str]:
+        """
+        Executing the inference step.
+
+        Args:
+            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            List[str]: Inference result returned by one generation.
+        """
+        self.generation_config = generation_config
+
+        output_list = []
+
+        while self.request_handler.check_unfinished_seqs():
+            output_list += self.step()
+
+        return output_list
+
+    def add_request(
+        self,
+        requests_id: List[int] = None,
+        prompts: List[str] = None,
+        prompts_token_ids: List[int] = None,
+    ) -> None:
+        """
+        Add requests.
+
+        Args:
+            requests_id (List[int], optional): The request ID. Defaults to None.
+            prompts (Union[List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+        """
+        block_size = self.inference_config.block_size
+
+        if prompts_token_ids is None:
+            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
+            prompts_token_ids = []
+            for prompt in prompts:
+                prompts_token_ids.append(self.tokenizer.encode(prompt))
+
+        prompts_num = len(prompts_token_ids)
+
+        for i in range(prompts_num):
+            if requests_id:
+                request_id = requests_id[i]
+            else:
+                request_id = next(self.counter)
+            if prompts == None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                None,
+                self.tokenizer.eos_token_id,
+                self.inference_config.max_output_len,
+            )
+            self.request_handler.add_sequence(sequence)
+
+    def step(self) -> List[str]:
+        """
+        In each step, do the follows:
+            1. Run RequestHandler.schedule() and get the batch used for inference.
+            2. Run model to generate the next token
+            3. Update waiting list and running list in RequestHandler and get finished sequences.
+            4. Decode and return finished sequences.
+
+        Returns:
+            List[str]: Decoded finished sequences generated by one step.
+        """
+        if self.verbose:
+            self.logger.info("Running generation step")
+
+        output_list = []
+        self.request_handler.schedule()
+
+        # Uncomment if the development of RequestHandler is completed.
+        # logits = self.model(batch)
+        # self.request_handler.search_tokens(logits, self.generation_config)
+
+        finished_sequences = self.request_handler.update()
+
+        # Decode completed sentences.
+        for seq in finished_sequences:
+            if seq.prompt:
+                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
+                output_list.append(seq.prompt + output_str)
+            else:
+                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
+                output_list.append(output_str)
+
+        return output_list
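
As a usage sketch of the new engine API (adapted from the tests/test_infer/test_engine.py diff further below, not an additional change in this commit): the engine is built from an in-memory model, a tokenizer and an InferenceConfig inside a colossalai.launch-initialized process with a CUDA device, requests are queued with add_request, and generate() drains them step by step. Since the model forward in step() is still commented out, generate() effectively echoes the prompts back, which is exactly what the new engine test asserts.

import transformers
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Assumes colossalai.launch has already set up a single-process group and a GPU is
# available (the sharded model is moved to CUDA inside _shardformer).
model = transformers.LlamaForCausalLM(
    transformers.LlamaConfig(
        vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
    )
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

engine = InferenceEngine(model, tokenizer, InferenceConfig(), verbose=True)
engine.add_request(prompts=["介绍一下北京", "介绍一下武汉"])
outputs = engine.generate(None)  # returns the decoded finished sequences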

View File

@@ -1,5 +1,7 @@
 from typing import List

+from colossalai.inference.struct import BatchInfo, Sequence
+

 class RequestHandler:
     """
@@ -7,14 +9,17 @@ class RequestHandler:
     During generation process, we call schedule function each iteration to update current batch.

     Args:
-        cache_config: Configuration for initialize and manage kv cache.
+        inference_config: Store the configuration information related to inference.
+        model_config: The huggingface model config.
     """

-    def __init__(self, cache_config) -> None:
-        self.cache_config = cache_config
+    def __init__(self, inference_config, model_config) -> None:
+        self.inference_config = inference_config
+        self.model_config = model_config
         self._init_cache()
-        self.waiting_list: List["Reqseq"] = []
-        self.running_list: List["Reqseq"] = []
+        self.waiting_list: List["Sequence"] = []
+        self.running_list: List["Sequence"] = []
+        self.batch = BatchInfo.init_batch()

     def _init_cache(self):
         """
@@ -25,12 +30,17 @@ class RequestHandler:
         """
         The main logic of request handler.
         """
+        # The code below is only used for testing engine and will be modified.
+        if self.waiting_list:
+            self.running_list = self.waiting_list
+            self.batch.add_seqs(self.running_list)
+        return self.batch

-    def add_sequence(self, reqseq: "Reqseq"):
+    def add_sequence(self, req_seq: "Sequence"):
         """
         Add the request to waiting list.
         """
-        self.waiting_list.append(reqseq)
+        self.waiting_list.append(req_seq)

     def abort_sequence(self, seq_id: str):
         """
@@ -39,10 +49,23 @@ class RequestHandler:
         self._find_sequence(seq_id)
         return

-    def _find_sequence(self, seq_id: str) -> "Reqseq":
+    def _find_sequence(self, seq_id: str) -> "Sequence":
         """
         Find the request by seq_id.
         """

     def check_unfinished_seqs(self) -> bool:
-        return self.waiting_list or self.running_list
+        return len(self.waiting_list) != 0 or len(self.running_list) != 0
+
+    def update(self):
+        """
+        Update the waiting list and running list.
+        """
+        # The code below is only used for testing engine and will be modified.
+        self.waiting_list = []
+        self.running_list = []
+        finished_sequences = list(self.batch.sequences_set)
+
+        self.batch.clear_batch()
+        return finished_sequences
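
The schedule/update pair is still a placeholder, but the intended driving loop (the one InferenceEngine.generate exercises above) looks roughly like the following sketch; inference_config, model_config and sequence are assumed to exist, and _init_cache must be able to consume them.

# Hypothetical driver loop for the placeholder RequestHandler.
handler = RequestHandler(inference_config, model_config)
handler.add_sequence(sequence)        # lands in waiting_list
while handler.check_unfinished_seqs():
    batch = handler.schedule()        # waiting_list -> running_list -> BatchInfo
    # model forward and token search would run here once implemented
    finished = handler.update()       # empties both lists and returns the batched sequences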

View File

@@ -135,7 +135,7 @@ class KVCacheManager:
         and updates the provided block table with the allocated block ids.

         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
             context_len: The length of the processing sequnece.
         """
         assert block_table.dim() == 1
@@ -185,7 +185,7 @@ class KVCacheManager:
         and updates the provided block table if a new cache block is needed.

         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
             context_len: The length of the processing sequnece (already-allocated length).
         """
         assert block_table.dim() == 1
@@ -199,7 +199,7 @@ class KVCacheManager:
         and updates the provided block table with the allocated block.

         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
             block_local_idx: The index of the block in the block table.
             space_asked: i.e. The number of tokens to be assigned space for.
         Returns:

View File

@@ -0,0 +1,7 @@
+from .llama import LlamaModelInferPolicy
+
+model_policy_map = {
+    "llama": LlamaModelInferPolicy,
+}
+
+__all__ = ["LlamaModelInferPolicy", "model_polic_map"]
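
This map is what lets InferenceEngine.__init__ pick a default policy from model.config.model_type when no policy is passed in; a small lookup sketch (values illustrative, not part of the diff):

from colossalai.inference.modeling.policy import model_policy_map

model_type = "llama"                     # e.g. taken from model.config.model_type
policy = model_policy_map[model_type]()  # -> LlamaModelInferPolicy instance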

View File

@@ -0,0 +1,7 @@
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+    # The code here just for test and will be modified later.
+    def __init__(self) -> None:
+        super().__init__()

View File

@@ -1,68 +1,82 @@
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Set
+from typing import List, Union
+
+import torch
+from ordered_set import OrderedSet
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)

 """
 The abstraction of request and sequence are defined here.
 """


-class RequsetStatus(enum.Enum):
-    """The status of Sentences"""
+class RequestStatus(enum.Enum):
+    """
+    The status of Sentences
+    """

+    # running status
     WAITING = enum.auto()
-    RUNNING = enum.auto()
+    PREFILL = enum.auto()
+    TOKEN = enum.auto()
     ABORTED = enum.auto()
+
+    # completion status
     OVERLENGTH = enum.auto()
     COMPLETED = enum.auto()
     LENGTH_CAPPED = enum.auto()

     @staticmethod
-    def is_finished(status: "RequsetStatus") -> bool:
+    def is_finished(status: "RequestStatus") -> bool:
         return status in [
-            RequsetStatus.OVERLENGTH,
-            RequsetStatus.COMPLETED,
-            RequsetStatus.LENGTH_CAPPED,
+            RequestStatus.OVERLENGTH,
+            RequestStatus.COMPLETED,
+            RequestStatus.LENGTH_CAPPED,
         ]

     @staticmethod
-    def is_running(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.RUNNING
+    def is_running(status: "RequestStatus") -> bool:
+        return status in [
+            RequestStatus.PREFILL,
+            RequestStatus.TOKEN,
+        ]

     @staticmethod
-    def is_waiting(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.WAITING
+    def is_waiting(status: "RequestStatus") -> bool:
+        return status == RequestStatus.WAITING


+@dataclass
 class Sequence:
     """Store information of input sequence.

     Args:
-        request_id: The ID of input sequence.
-        prompt: The prompt of input sequence.
-        token_id: The tokens ID of input sequence.
-        block_size: The block size of input sequence.
-        sample_params: The sample_params of input sequence.
-        block_table_index: The index of input sequence in block_table.
+        request_id (int): The ID of input sequence.
+        prompt (str): The prompt of input sequence.
+        input_token_id (List[int]): The tokens ID of input sequence.
+        block_size (int): The block size of input sequence.
+        sample_params (SampleParams): The sample_params of input sequence.
+        block_table (torch.Tensor): The index of input sequence in block_table.
+        eos_token_id (int): The eos token id for this inference process.
+        max_output_len (int): Maximum output length.
     """

-    def __init__(
-        self,
-        request_id: int,
-        prompt: str,
-        token_id: List[int],
-        block_size: int,
-        sample_params,  # SampleParams needs to be imported later.
-        block_table_index: int,
-    ):
-        self.request_id = request_id
-        self.prompt = prompt
-        self.input_token_id = token_id
-        self.blokc_size = block_size
-        self.sample_params = sample_params
+    request_id: int
+    prompt: str
+    input_token_id: List[int]
+    block_size: int
+    sample_params: any  # SampleParams needs to be imported later.
+    block_table: torch.Tensor
+    eos_token_id: int
+    max_output_len: int = 256
+
+    def __post_init__(self):
         self.output_token_id = []
-        self.status = RequsetStatus.WAITING
-        self.block_table_index = block_table_index
+        self.status = RequestStatus.WAITING

     def get_sentence_len(self) -> None:
         """
@@ -84,17 +98,30 @@ class Sequence:
     def check_finish(self) -> bool:
         """
-        Check whether inference is over.
+        Check whether the inference is finished.
+
+        Returns:
+            bool: Whether the inference is finished.
         """
-        return RequsetStatus.is_finished(self.status)
+        if RequestStatus.is_finished(self.status):
+            return True
+
+        if self.output_token_id:
+            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+                self.status = RequestStatus.COMPLETED
+                return True
+
+        return False
+
+    def __hash__(self):
+        return hash(self.request_id)

     def __repr__(self) -> str:
         return (
             f"Request ID(request_id={self.request_id}, "
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
-            f"sample_params={self.sample_params}, "
-            f"logical block number={len(self._logical_blocks)}"
+            f"sample_params={self.sample_params}"
         )
@@ -104,34 +131,38 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """

-    sequences_set: Set[Sequence]
-    block_table: Dict[int, int] = None
+    sequences_set: OrderedSet["Sequence"]

     @classmethod
-    def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
+    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
         """
         Initializes inference batches by input sentence list.

         Args:
-            seqs (List[Sequence]): List of input sequence.
+            seqs (List["Sequence"]): List of input sequence.
         """
-        sequences_set = set()
-        block_table = {}
-        for seq in seqs:
-            if seq in sequences_set:
-                assert (
-                    seq.request_id in block_table.keys()
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
-                continue
-            assert (
-                seq.request_id not in block_table.keys()
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
-            sequences_set.add(seq)
-            block_table[seq.request_id] = seq.block_table_index
-        return cls(sequences_set=sequences_set, block_table=block_table)
+
+        sequences_set = OrderedSet()
+
+        if seqs is not None:
+            if not isinstance(seqs, list):
+                seqs = [seqs]
+            for seq in seqs:
+                if seq in sequences_set:
+                    logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
+                    continue
+                sequences_set.add(seq)
+
+        return cls(sequences_set=sequences_set)
+
+    def get_block_table_tensor(self):
+        tesnor_list = []
+        for seq in self.sequences_set:
+            block_table = seq.block_table
+            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+            tesnor_list.append(seq.block_table)
+        return torch.concat(tesnor_list)

     def clear_batch(self) -> None:
         """
@@ -139,35 +170,76 @@ class BatchInfo:
         """
         for seq in self.sequences_set:
             if not seq.check_finish():
-                seq.status = RequsetStatus.ABORTED
+                seq.status = RequestStatus.ABORTED
         self.sequences_set.clear()
-        self.block_table.clear()

-    def fliter_batch(self) -> None:
+    def fliter_batch(self) -> List["Sequence"]:
         """
         Remove completed sentences from a batch.
+
+        Returns:
+            List["Sequence"]: List of finished sequences.
         """
-        for seq in self.sequences_set.copy():
-            if seq.check_finish():
-                self.sequences_set.remove(seq)
-                del self.block_table[seq.request_id]
+        finish_seqs = []
+        for seq in self.sequences_set:
+            if seq.check_finish():
+                finish_seqs.append(seq)
+        for finish_seq in finish_seqs:
+            self.sequences_set.discard(finish_seq)
+        return finish_seqs
+
+    def abort_seq(self, seq: "Sequence") -> "Sequence":
+        """
+        Remove sequence from the batch.
+        """
+        if not seq.check_finish():
+            seq.status = RequestStatus.ABORTED
+        self.sequences_set.discard(seq)
+        return seq

-    def add_seqs(self, seqs: List[Sequence]) -> None:
+    def add_seqs(self, seqs: List["Sequence"]) -> None:
         """
         Add new sequence to batch

         Args:
-            seqs (List[Sequence]): The list of new sequences.
+            seqs (List["Sequence"]): The list of new sequences.
         """
+        if not isinstance(seqs, list):
+            seqs = [seqs]
+
         for seq in seqs:
             if seq in self.sequences_set:
-                print("The sequence is already in sequences_set.")
-                assert (
-                    seq.request_id in self.block_table
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
+                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
-            assert (
-                seq.request_id not in self.block_table
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
             self.sequences_set.add(seq)
-            self.block_table[seq.request_id] = seq.block_table_index
+
+    def is_empty(self) -> None:
+        """
+        Check whether sequences_set is empty.
+        """
+        return not self.sequences_set
+
+    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+        """
+        Add an output token for each sentence in the batch.
+
+        Args:
+            tokens (List[int]): A batch of tokens
+        """
+        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
+
+        for seq, token in zip(self.sequences_set, tokens):
+            if not isinstance(token, list):
+                if not isinstance(token, int):
+                    raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
+                token = [token]
+            seq.output_token_id += token
+            seq.check_finish()
+
+    def get_batch_size(self) -> int:
+        """
+        Get batch_size of this batch
+        """
+        return len(self.sequences_set)
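
A compact bookkeeping sketch for the updated Sequence/BatchInfo API, mirroring the reworked test_config_and_struct.py below (block_table and sample_params are left as None just as the test does):

from colossalai.inference.struct import BatchInfo, Sequence

seq = Sequence(request_id=1, prompt="abc", input_token_id=[1, 2, 3], block_size=16,
               sample_params=None, block_table=None, eos_token_id=2, max_output_len=256)

batch = BatchInfo.init_batch([seq])
batch.update_batch_tokens([5])   # appends one output token per sequence, then re-checks finish status
finished = batch.fliter_batch()  # returns and removes any sequences that have finished
print(batch.get_batch_size(), batch.is_empty())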

View File

@@ -1,4 +1,5 @@
+ordered_set
 transformers==4.34.0
 auto-gptq==0.5.0
 git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9

View File

@@ -1,4 +1,6 @@
 diffusers
+fbgemm-gpu==0.2.0
+ordered_set
 pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon

tests/test_infer/_utils.py (0 changes) Normal file → Executable file
View File

tests/test_infer/test_config_and_struct.py (70 changes) Normal file → Executable file
View File

@@ -1,26 +1,45 @@
+import pytest
+
+import colossalai
 from colossalai.inference.config import InferenceConfig
-from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
+from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.testing import spawn


-def test_config_and_inferenceData():
-    config = InferenceConfig("/llama")
-    assert config.max_batch_size
+def check_config_and_inference():
+    config = InferenceConfig()
+    assert config.max_batch_size == 8
     sequence = Sequence(
         request_id=1,
         prompt="abc",
-        token_id=[1, 2, 3],
+        input_token_id=[1, 2, 3],
         block_size=16,
         sample_params=None,
-        block_table_index=1,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
     )

     sequence2 = Sequence(
         request_id=2,
         prompt="bcd",
-        token_id=[4, 5, 6],
+        input_token_id=[4, 5, 6],
         block_size=16,
         sample_params=None,
-        block_table_index=2,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
+    )
+
+    sequence3 = Sequence(
+        request_id=3,
+        prompt="efg",
+        input_token_id=[7, 8, 9],
+        block_size=16,
+        sample_params=None,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
     )

     assert sequence.get_sentence_len() == 3
@@ -29,15 +48,34 @@ def check_config_and_inference():
     assert sequence.check_finish() == False

     batch = BatchInfo.init_batch([sequence])
-    assert batch.block_table[sequence.request_id] == sequence.block_table_index
-    sequence.status = RequsetStatus.COMPLETED
-    batch.fliter_batch()
-    assert batch.block_table == {}
-    batch.add_seqs([sequence2])
-    assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
+    batch.add_seqs([sequence2, sequence3])
+    batch.add_seqs([sequence])
+
+    assert batch.is_empty() == False
+    assert batch.get_batch_size() == 3
+
+    batch.update_batch_tokens([1, 2, 3])
+    seq = batch.abort_seq(sequence)
+    seq2 = batch.fliter_batch()[0]
+
+    assert batch.get_batch_size() == 1
+    assert seq.get_output_len() == 1
+    assert seq.output_token_id == [1]
+    assert seq2.get_output_len() == 1
+    assert seq2.output_token_id == [2]
+
     batch.clear_batch()
-    assert batch.block_table == {}
+    assert batch.is_empty() == True
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_config_and_inference()
+
+
+@pytest.mark.dist
+def test_config_and_inference():
+    spawn(run_dist, 1)


 if __name__ == "__main__":
-    test_config_and_inferenceData()
+    test_config_and_inference()

View File

@@ -0,0 +1,44 @@
+import pytest
+import transformers
+from transformers import AutoTokenizer
+
+import colossalai
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import spawn
+
+
+def check_inference_engine():
+    model = transformers.LlamaForCausalLM(
+        transformers.LlamaConfig(
+            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+        )
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    inference_config = InferenceConfig()
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+
+    inputs = [
+        "介绍一下北京",
+        "介绍一下武汉",
+    ]
+
+    inference_engine.add_request(prompts=inputs)
+    outputs = inference_engine.generate(None)
+
+    for s1, s2 in zip(inputs, outputs):
+        assert s1 == s2
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_inference_engine()
+
+
+@pytest.mark.dist
+def test_inference_engine():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_inference_engine()

tests/test_infer/test_kvcache_manager.py (18 changes) Normal file → Executable file
View File

@@ -1,12 +1,14 @@
 import random

+import pytest
 import torch
 from transformers.models.llama import LlamaConfig

+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, spawn


 @parameterize(
@@ -64,7 +66,7 @@ def test_logical_blocks(test_config):
         },
     ],
 )
-def test_cache_manager(test_config):
+def check_cache_manager(test_config):
     disable_existing_loggers()

     assert test_config["max_batch_size"] > 1
@@ -78,7 +80,7 @@ def check_cache_manager(test_config):
     max_input_length = test_config["max_input_len"]
     max_output_length = test_config["max_output_len"]

-    inference_config = InferenceConfig(model="", **test_config)
+    inference_config = InferenceConfig(**test_config)
     model_config = LlamaConfig(
         hidden_size=hidden_size,
         num_hidden_layers=num_layers,
@@ -147,6 +149,16 @@ def check_cache_manager(test_config):
     assert cache_manager.get_num_available_blocks() == num_blocks


+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_cache_manager()
+
+
+@pytest.mark.dist
+def test_cache_manager():
+    spawn(run_dist, 1)
+
+
 if __name__ == "__main__":
     test_logical_blocks()
     test_cache_manager()