[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config

* update codes

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update codes

* made adjustments according to the feedback from the reviewer.

* update codes

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update codes

* update codes

* update model and tokenizer

* fix add the logic about shardformer

* change kvcache_manager docstring

* add policy

* fix ci bug in test_kvcache_manager.py

* remove codes related o tokenizer and move model_policy

* fix  code style

* add ordered_set to requirements-infer.txt

* Delete extra empty lines

* add ordered_set to requirements-test.txt
pull/5258/head
yuehuayingxueluo 2023-12-18 10:40:47 +08:00 committed by FrankLeeeee
parent 93aeacca34
commit 8daee26989
13 changed files with 555 additions and 172 deletions

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.distributed as dist
GibiByte = 1024**3
@ -15,44 +15,44 @@ class InferenceConfig:
"""The inference configuration.
Args:
model: Path or nn.Module of this model.
tokenizer: Path of the tokenizer to use.
tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Whether to trust remote code from huggingface.
max_batch_size: Maximum batch size.
max_output_len: Maximum output length.
max_input_len: Maximum input length.
block_size: The number of blocks in a logical block.
dtype: The data type for weights and activations.
tp_size: Tensor parallel size.
pp_size: Pipeline parallel size.
max_seq_len: Maximum length of input sentence.
quant_mode: Quantization mode.
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
beam_width: The maximum beam width used to initialize KV Cache.
micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
max_batch_size (int): Maximum batch size.
max_output_len (int): Maximum output length.
max_input_len (int): Maximum input length.
block_size (int): The number of blocks in a logical block.
dtype (Union[str, torch.dtype]): The data type for weights and activations.
tp_size (int): Tensor parallel size.
pp_size (int): Pipeline parallel size.
max_seq_len (int): Maximum length of input sentence.
beam_width (int): The maximum beam width used to initialize KV Cache.
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
when the actual value exceeds this ratio.
quant_mode (Optional[str]): Quantization mode.
revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
"""
model: Union[str, nn.Module]
tokenizer: str = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
max_batch_size: int = None
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
max_batch_size: int = 8
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
dtype: Union[str, torch.dtype] = torch.float32
tp_size: int = 1
pp_size: int = 1
max_seq_len: Optional[int] = None
max_seq_len: int = 512
# TODO: beam search is not support for now
beam_width: int = 1
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
prefill_ratio: Optional[float] = 1.2
quant_mode: Optional[str] = None
revision: Optional[str] = None
beam_width: int = 1
# TODO: beam search is not support for now
prefill_ratio: Optional[float] = 1.2
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
def __post_init__(self):
self._init_batch_size()
self._verify_config()
def _init_batch_size(self):
"""
@ -75,10 +75,20 @@ class InferenceConfig:
f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
)
def __post_init__(self):
self._init_batch_size()
self._verify_args()
def _verify_args(self):
if self.tokenizer_mode not in ["auto", "slow"]:
raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
def _verify_config(self) -> None:
"""
Verify the input config
"""
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert self.dtype in [
"fp16",
"fp32",
"bf16",
torch.float32,
torch.float16,
torch.bfloat16,
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"

View File

@ -1,65 +1,232 @@
from logging import Logger
from typing import Optional
from itertools import count
from typing import List, Optional, Union
from transformers import AutoConfig
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from .request_handler import RequestHandler
PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
]
class InferenceEngine:
"""
InferenceEngine is the core component for Inference.
It is responsible for launch the inference process, including:
- Initialize model and distributed training environment(if needed)
- Launch request_handler and corresponding kv cache manager
- Receive requests and generate texts.
- Log the generation process
"""
InferenceEngine which manages the inference process..
Args:
tokenizer: Path of the tokenizer to use.
inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
model (nn.Module): Path or nn.Module of this model.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
"""
def __init__(
self,
tokenizer: str = None,
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: Optional["InferenceConfig"] = None,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
self._init_model()
# cache_config may need to be modified later.
# self.request_handler = RequestHandler(cache_config)
self.tokenizer = tokenizer
self.hf_model_config = AutoConfig.from_pretrained(
self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
self.inference_config = inference_config
self.model_config = model.config
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
self.model = self._shardformer(
model,
model_policy,
None,
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
)
self.verbose = verbose
if verbose:
self.logger = Logger()
self.logger = get_dist_logger(__name__)
def _init_model(self):
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.counter = count()
def _verify_config(self) -> None:
"""
Initialize model and distributed training environment(if needed).
May need to provide two different initialization methods:
1. 用户自定义(from local path)
2. 从checkpoint加载(hugging face)
Verify the input config
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
), f"Model {self.model.__class__.__name__} is not supported."
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: _description_
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"quant": self.inference_config.quant_mode},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
def generate(
self,
generation_config: GenerationConfig = None,
) -> List[str]:
"""
Executing the inference step.
Args:
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
List[str]: Inference result returned by one generation.
"""
def _verify_config(self):
self.generation_config = generation_config
output_list = []
while self.request_handler.check_unfinished_seqs():
output_list += self.step()
return output_list
def add_request(
self,
requests_id: List[int] = None,
prompts: List[str] = None,
prompts_token_ids: List[int] = None,
) -> None:
"""
Verify the configuration to avoid potential bugs.
Add requests.
Args:
requests_id (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
"""
def generate(self):
pass
block_size = self.inference_config.block_size
def step(self):
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = []
for prompt in prompts:
prompts_token_ids.append(self.tokenizer.encode(prompt))
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if requests_id:
request_id = requests_id[i]
else:
request_id = next(self.counter)
if prompts == None:
prompt = None
else:
prompt = prompts[i]
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
None,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
def step(self) -> List[str]:
"""
In each step, do the follows:
1. Run request_handler to update the kv cache and running input_ids
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run model to generate the next token
3. Check whether there is finied request and decode
3. Update waiting list and running list in RequestHandler and get finished sequences.
4. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
"""
if self.verbose:
self.logger.info("Running generation step")
output_list = []
self.request_handler.schedule()
# Uncomment if the development of RequestHandler is completed.
# logits = self.model(batch)
# self.request_handler.search_tokens(logits, self.generation_config)
finished_sequences = self.request_handler.update()
# Decode completed sentences.
for seq in finished_sequences:
if seq.prompt:
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
output_list.append(seq.prompt + output_str)
else:
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
output_list.append(output_str)
return output_list

View File

@ -1,5 +1,7 @@
from typing import List
from colossalai.inference.struct import BatchInfo, Sequence
class RequestHandler:
"""
@ -7,14 +9,17 @@ class RequestHandler:
During generation process, we call schedule function each iteration to update current batch.
Args:
cache_config: Configuration for initialize and manage kv cache.
inference_config: Store the configuration information related to inference.
model_config: The huggingface model config.
"""
def __init__(self, cache_config) -> None:
self.cache_config = cache_config
def __init__(self, inference_config, model_config) -> None:
self.inference_config = inference_config
self.model_config = model_config
self._init_cache()
self.waiting_list: List["Reqseq"] = []
self.running_list: List["Reqseq"] = []
self.waiting_list: List["Sequence"] = []
self.running_list: List["Sequence"] = []
self.batch = BatchInfo.init_batch()
def _init_cache(self):
"""
@ -25,12 +30,17 @@ class RequestHandler:
"""
The main logic of request handler.
"""
# The code below is only used for testing engine and will be modified.
if self.waiting_list:
self.running_list = self.waiting_list
self.batch.add_seqs(self.running_list)
return self.batch
def add_sequence(self, reqseq: "Reqseq"):
def add_sequence(self, req_seq: "Sequence"):
"""
Add the request to waiting list.
"""
self.waiting_list.append(reqseq)
self.waiting_list.append(req_seq)
def abort_sequence(self, seq_id: str):
"""
@ -39,10 +49,23 @@ class RequestHandler:
self._find_sequence(seq_id)
return
def _find_sequence(self, seq_id: str) -> "Reqseq":
def _find_sequence(self, seq_id: str) -> "Sequence":
"""
Find the request by seq_id.
"""
def check_unfinished_seqs(self) -> bool:
return self.waiting_list or self.running_list
return len(self.waiting_list) != 0 or len(self.running_list) != 0
def update(self):
"""
Update the waiting list and running list.
"""
# The code below is only used for testing engine and will be modified.
self.waiting_list = []
self.running_list = []
finished_sequences = list(self.batch.sequences_set)
self.batch.clear_batch()
return finished_sequences

View File

@ -135,7 +135,7 @@ class KVCacheManager:
and updates the provided block table with the allocated block ids.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
context_len: The length of the processing sequnece.
"""
assert block_table.dim() == 1
@ -185,7 +185,7 @@ class KVCacheManager:
and updates the provided block table if a new cache block is needed.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
context_len: The length of the processing sequnece (already-allocated length).
"""
assert block_table.dim() == 1
@ -199,7 +199,7 @@ class KVCacheManager:
and updates the provided block table with the allocated block.
Args:
block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
block_local_idx: The index of the block in the block table.
space_asked: i.e. The number of tokens to be assigned space for.
Returns:

View File

@ -0,0 +1,7 @@
from .llama import LlamaModelInferPolicy
model_policy_map = {
"llama": LlamaModelInferPolicy,
}
__all__ = ["LlamaModelInferPolicy", "model_polic_map"]

View File

@ -0,0 +1,7 @@
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
# The code here just for test and will be modified later.
def __init__(self) -> None:
super().__init__()

View File

@ -1,68 +1,82 @@
import enum
from dataclasses import dataclass
from typing import Dict, List, Set
from typing import List, Union
import torch
from ordered_set import OrderedSet
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
"""
The abstraction of request and sequence are defined here.
"""
class RequsetStatus(enum.Enum):
"""The status of Sentences"""
class RequestStatus(enum.Enum):
"""
The status of Sentences
"""
# running status
WAITING = enum.auto()
RUNNING = enum.auto()
PREFILL = enum.auto()
TOKEN = enum.auto()
ABORTED = enum.auto()
# completion status
OVERLENGTH = enum.auto()
COMPLETED = enum.auto()
LENGTH_CAPPED = enum.auto()
@staticmethod
def is_finished(status: "RequsetStatus") -> bool:
def is_finished(status: "RequestStatus") -> bool:
return status in [
RequsetStatus.OVERLENGTH,
RequsetStatus.COMPLETED,
RequsetStatus.LENGTH_CAPPED,
RequestStatus.OVERLENGTH,
RequestStatus.COMPLETED,
RequestStatus.LENGTH_CAPPED,
]
@staticmethod
def is_running(status: "RequsetStatus") -> bool:
return status == RequsetStatus.RUNNING
def is_running(status: "RequestStatus") -> bool:
return status in [
RequestStatus.PREFILL,
RequestStatus.TOKEN,
]
@staticmethod
def is_waiting(status: "RequsetStatus") -> bool:
return status == RequsetStatus.WAITING
def is_waiting(status: "RequestStatus") -> bool:
return status == RequestStatus.WAITING
@dataclass
class Sequence:
"""Store information of input sequence.
Args:
request_id: The ID of input sequence.
prompt: The prompt of input sequence.
token_id: The tokens ID of input sequence.
block_size: The block size of input sequence.
sample_params: The sample_params of input sequence.
block_table_index: The index of input sequence in block_table.
request_id (int): The ID of input sequence.
prompt (str): The prompt of input sequence.
input_token_id (List[int]): The tokens ID of input sequence.
block_size (int): The block size of input sequence.
sample_params (SampleParams): The sample_params of input sequence.
block_table (torch.Tensor): The index of input sequence in block_table.
eos_token_id (int): The eos token id for this inference process.
max_output_len (int): Maximum output length.
"""
def __init__(
self,
request_id: int,
prompt: str,
token_id: List[int],
block_size: int,
sample_params, # SampleParams needs to be imported later.
block_table_index: int,
):
self.request_id = request_id
self.prompt = prompt
self.input_token_id = token_id
self.blokc_size = block_size
self.sample_params = sample_params
request_id: int
prompt: str
input_token_id: List[int]
block_size: int
sample_params: any # SampleParams needs to be imported later.
block_table: torch.Tensor
eos_token_id: int
max_output_len: int = 256
def __post_init__(self):
self.output_token_id = []
self.status = RequsetStatus.WAITING
self.block_table_index = block_table_index
self.status = RequestStatus.WAITING
def get_sentence_len(self) -> None:
"""
@ -84,17 +98,30 @@ class Sequence:
def check_finish(self) -> bool:
"""
Check whether inference is over.
Check whether the inference is finished.
Returns:
bool: Whether the inference is finished.
"""
return RequsetStatus.is_finished(self.status)
if RequestStatus.is_finished(self.status):
return True
if self.output_token_id:
if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
self.status = RequestStatus.COMPLETED
return True
return False
def __hash__(self):
return hash(self.request_id)
def __repr__(self) -> str:
return (
f"Request ID(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
f"logical block number={len(self._logical_blocks)}"
f"sample_params={self.sample_params}"
)
@ -104,34 +131,38 @@ class BatchInfo:
Information to be passed and used for a batch of sequences.
"""
sequences_set: Set[Sequence]
block_table: Dict[int, int] = None
sequences_set: OrderedSet["Sequence"]
@classmethod
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
"""
Initializes inference batches by input sentence list.
Args:
seqs (List[Sequence]): List of input sequence.
seqs (List["Sequence"]): List of input sequence.
"""
sequences_set = set()
block_table = {}
for seq in seqs:
if seq in sequences_set:
assert (
seq.request_id in block_table.keys()
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue
assert (
seq.request_id not in block_table.keys()
), "The sequence has not been added to sequences_set, but it is already in block_table."
sequences_set = OrderedSet()
sequences_set.add(seq)
block_table[seq.request_id] = seq.block_table_index
if seqs is not None:
if not isinstance(seqs, list):
seqs = [seqs]
for seq in seqs:
if seq in sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
return cls(sequences_set=sequences_set, block_table=block_table)
sequences_set.add(seq)
return cls(sequences_set=sequences_set)
def get_block_table_tensor(self):
tesnor_list = []
for seq in self.sequences_set:
block_table = seq.block_table
assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
tesnor_list.append(seq.block_table)
return torch.concat(tesnor_list)
def clear_batch(self) -> None:
"""
@ -139,35 +170,76 @@ class BatchInfo:
"""
for seq in self.sequences_set:
if not seq.check_finish():
seq.status = RequsetStatus.ABORTED
seq.status = RequestStatus.ABORTED
self.sequences_set.clear()
self.block_table.clear()
def fliter_batch(self) -> None:
def fliter_batch(self) -> List["Sequence"]:
"""
Remove completed sentences from a batch.
"""
for seq in self.sequences_set.copy():
if seq.check_finish():
self.sequences_set.remove(seq)
del self.block_table[seq.request_id]
def add_seqs(self, seqs: List[Sequence]) -> None:
Returns:
List["Sequence"]: List of finished sequences.
"""
finish_seqs = []
for seq in self.sequences_set:
if seq.check_finish():
finish_seqs.append(seq)
for finish_seq in finish_seqs:
self.sequences_set.discard(finish_seq)
return finish_seqs
def abort_seq(self, seq: "Sequence") -> "Sequence":
"""
Remove sequence from the batch.
"""
if not seq.check_finish():
seq.status = RequestStatus.ABORTED
self.sequences_set.discard(seq)
return seq
def add_seqs(self, seqs: List["Sequence"]) -> None:
"""
Add new sequence to batch
Args:
seqs (List[Sequence]): The list of new sequences.
seqs (List["Sequence"]): The list of new sequences.
"""
if not isinstance(seqs, list):
seqs = [seqs]
for seq in seqs:
if seq in self.sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in self.block_table
), "The sequence has been added to sequences_set, but it has not been added to block_table."
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
assert (
seq.request_id not in self.block_table
), "The sequence has not been added to sequences_set, but it is already in block_table."
self.sequences_set.add(seq)
self.block_table[seq.request_id] = seq.block_table_index
def is_empty(self) -> None:
"""
Check whether sequences_set is empty.
"""
return not self.sequences_set
def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
"""
Add an output token for each sentence in the batch.
Args:
tokens (List[int]): A batch of tokens
"""
assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
for seq, token in zip(self.sequences_set, tokens):
if not isinstance(token, list):
if not isinstance(token, int):
raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
token = [token]
seq.output_token_id += token
seq.check_finish()
def get_batch_size(self) -> int:
"""
Get batch_size of this batch
"""
return len(self.sequences_set)

View File

@ -1,4 +1,5 @@
ordered_set
transformers==4.34.0
auto-gptq==0.5.0
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9

View File

@ -1,4 +1,6 @@
diffusers
fbgemm-gpu==0.2.0
ordered_set
pytest
coverage==7.2.3
git+https://github.com/hpcaitech/pytest-testmon

0
tests/test_infer/_utils.py Normal file → Executable file
View File

70
tests/test_infer/test_config_and_struct.py Normal file → Executable file
View File

@ -1,26 +1,45 @@
import pytest
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.testing import spawn
def test_config_and_inferenceData():
config = InferenceConfig("/llama")
assert config.max_batch_size
def check_config_and_inference():
config = InferenceConfig()
assert config.max_batch_size == 8
sequence = Sequence(
request_id=1,
prompt="abc",
token_id=[1, 2, 3],
input_token_id=[1, 2, 3],
block_size=16,
sample_params=None,
block_table_index=1,
block_table=None,
eos_token_id=2,
max_output_len=256,
)
sequence2 = Sequence(
request_id=2,
prompt="bcd",
token_id=[4, 5, 6],
input_token_id=[4, 5, 6],
block_size=16,
sample_params=None,
block_table_index=2,
block_table=None,
eos_token_id=2,
max_output_len=256,
)
sequence3 = Sequence(
request_id=3,
prompt="efg",
input_token_id=[7, 8, 9],
block_size=16,
sample_params=None,
block_table=None,
eos_token_id=2,
max_output_len=256,
)
assert sequence.get_sentence_len() == 3
@ -29,15 +48,34 @@ def test_config_and_inferenceData():
assert sequence.check_finish() == False
batch = BatchInfo.init_batch([sequence])
assert batch.block_table[sequence.request_id] == sequence.block_table_index
sequence.status = RequsetStatus.COMPLETED
batch.fliter_batch()
assert batch.block_table == {}
batch.add_seqs([sequence2])
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
batch.add_seqs([sequence2, sequence3])
batch.add_seqs([sequence])
assert batch.is_empty() == False
assert batch.get_batch_size() == 3
batch.update_batch_tokens([1, 2, 3])
seq = batch.abort_seq(sequence)
seq2 = batch.fliter_batch()[0]
assert batch.get_batch_size() == 1
assert seq.get_output_len() == 1
assert seq.output_token_id == [1]
assert seq2.get_output_len() == 1
assert seq2.output_token_id == [2]
batch.clear_batch()
assert batch.block_table == {}
assert batch.is_empty() == True
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_config_and_inference()
@pytest.mark.dist
def test_config_and_inference():
spawn(run_dist, 1)
if __name__ == "__main__":
test_config_and_inferenceData()
test_config_and_inference()

View File

@ -0,0 +1,44 @@
import pytest
import transformers
from transformers import AutoTokenizer
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import spawn
def check_inference_engine():
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
inference_config = InferenceConfig()
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inputs = [
"介绍一下北京",
"介绍一下武汉",
]
inference_engine.add_request(prompts=inputs)
outputs = inference_engine.generate(None)
for s1, s2 in zip(inputs, outputs):
assert s1 == s2
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_inference_engine()
@pytest.mark.dist
def test_inference_engine():
spawn(run_dist, 1)
if __name__ == "__main__":
test_inference_engine()

18
tests/test_infer/test_kvcache_manager.py Normal file → Executable file
View File

@ -1,12 +1,14 @@
import random
import pytest
import torch
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize
from colossalai.testing import parameterize, spawn
@parameterize(
@ -64,7 +66,7 @@ def test_logical_blocks(test_config):
},
],
)
def test_cache_manager(test_config):
def check_cache_manager(test_config):
disable_existing_loggers()
assert test_config["max_batch_size"] > 1
@ -78,7 +80,7 @@ def test_cache_manager(test_config):
max_input_length = test_config["max_input_len"]
max_output_length = test_config["max_output_len"]
inference_config = InferenceConfig(model="", **test_config)
inference_config = InferenceConfig(**test_config)
model_config = LlamaConfig(
hidden_size=hidden_size,
num_hidden_layers=num_layers,
@ -147,6 +149,16 @@ def test_cache_manager(test_config):
assert cache_manager.get_num_available_blocks() == num_blocks
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_cache_manager()
@pytest.mark.dist
def test_cache_manager():
spawn(run_dist, 1)
if __name__ == "__main__":
test_logical_blocks()
test_cache_manager()