mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer
* update codes
* add ci test for config and struct
* Add the logic of the inference engine
* update engine and test
* Recover cache_manager.py
* add logger
* fix conflict
* update codes
* update codes
* update model and tokenizer
* fix add the logic about shardformer
* change kvcache_manager docstring
* add policy
* fix ci bug in test_kvcache_manager.py
* remove codes related to tokenizer and move model_policy
* fix code style
* add ordered_set to requirements-infer.txt
* Delete extra empty lines
* add ordered_set to requirements-test.txt

Branch: pull/5258/head
parent 93aeacca34
commit 8daee26989
colossalai/inference/config.py
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
-import torch.nn as nn
+import torch.distributed as dist
 
 GibiByte = 1024**3
 
@@ -15,44 +15,44 @@ class InferenceConfig:
     """The inference configuration.
 
     Args:
-        model: Path or nn.Module of this model.
-        tokenizer: Path of the tokenizer to use.
-        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
-        trust_remote_code: Whether to trust remote code from huggingface.
-        max_batch_size: Maximum batch size.
-        max_output_len: Maximum output length.
-        max_input_len: Maximum input length.
-        block_size: The number of blocks in a logical block.
-        dtype: The data type for weights and activations.
-        tp_size: Tensor parallel size.
-        pp_size: Pipeline parallel size.
-        max_seq_len: Maximum length of input sentence.
-        quant_mode: Quantization mode.
-        revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
-        beam_width: The maximum beam width used to initialize KV Cache.
+        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+        max_batch_size (int): Maximum batch size.
+        max_output_len (int): Maximum output length.
+        max_input_len (int): Maximum input length.
+        block_size (int): The number of blocks in a logical block.
+        dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        tp_size (int): Tensor parallel size.
+        pp_size (int): Pipeline parallel size.
+        max_seq_len (int): Maximum length of input sentence.
+        beam_width (int): The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
             when the actual value exceeds this ratio.
+        quant_mode (Optional[str]): Quantization mode.
+        revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
     """
 
-    model: Union[str, nn.Module]
-    tokenizer: str = None
-    tokenizer_mode: str = "auto"
-    trust_remote_code: bool = False
-    max_batch_size: int = None
+    micro_batch_size: int = 1
+    micro_batch_buffer_size: int = None
+    max_batch_size: int = 8
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
-    max_seq_len: Optional[int] = None
-    quant_mode: Optional[str] = None
-    revision: Optional[str] = None
-    beam_width: int = 1
-    # TODO: beam search is not support for now
-    prefill_ratio: Optional[float] = 1.2
-    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    max_seq_len: int = 512
+    # TODO: beam search is not support for now
+    beam_width: int = 1
+    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    prefill_ratio: Optional[float] = 1.2
+    quant_mode: Optional[str] = None
+    revision: Optional[str] = None
+
+    def __post_init__(self):
+        self._init_batch_size()
+        self._verify_config()
 
     def _init_batch_size(self):
         """
@@ -75,10 +75,20 @@ class InferenceConfig:
                 f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
             )
 
-    def __post_init__(self):
-        self._init_batch_size()
-        self._verify_args()
-
-    def _verify_args(self):
-        if self.tokenizer_mode not in ["auto", "slow"]:
-            raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        assert (
+            self.tp_size * self.pp_size == dist.get_world_size()
+        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+        assert self.dtype in [
+            "fp16",
+            "fp32",
+            "bf16",
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
+        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
+        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
+        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
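As a quick orientation before the engine diff below: the reworked InferenceConfig is a plain dataclass whose __post_init__ now runs _init_batch_size() and _verify_config(), and the latter compares tp_size * pp_size against torch.distributed's world size. A minimal construction sketch follows; the launch arguments (rank 0, world size 1, an arbitrary free port) are assumptions for a single-process run, everything else is taken from the diff above.

import colossalai
from colossalai.inference.config import InferenceConfig

# Assumed single-process setup: colossalai.launch initializes torch.distributed,
# which InferenceConfig._verify_config() queries via dist.get_world_size().
colossalai.launch(config={}, rank=0, world_size=1, port=29500, host="localhost")

# "fp16" is one of the dtype values _verify_config() accepts; max_batch_size must stay <= 64.
inference_config = InferenceConfig(
    max_batch_size=8,
    max_input_len=256,
    max_output_len=256,
    dtype="fp16",
    tp_size=1,
    pp_size=1,
)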
colossalai/inference/core/engine.py
@@ -1,65 +1,232 @@
-from logging import Logger
-from typing import Optional
-
-from transformers import AutoConfig
-
-from colossalai.inference.config import InferenceConfig
-
-
-class InferenceEngine:
-    """
-    InferenceEngine is the core component for Inference.
-
-    It is responsible for launch the inference process, including:
-    - Initialize model and distributed training environment(if needed)
-    - Launch request_handler and corresponding kv cache manager
-    - Receive requests and generate texts.
-    - Log the generation process
-
-    Args:
-        tokenizer: Path of the tokenizer to use.
-        inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
-        verbose (bool): Determine whether or not to log the generation process.
-    """
-
-    def __init__(
-        self,
-        tokenizer: str = None,
-        inference_config: Optional["InferenceConfig"] = None,
-        verbose: bool = False,
-    ) -> None:
-        assert inference_config, "Please provide inference_config."
-
-        self._init_model()
-        # cache_config may need to be modified later.
-        # self.request_handler = RequestHandler(cache_config)
-        self.tokenizer = tokenizer
-        self.hf_model_config = AutoConfig.from_pretrained(
-            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
-        )
-        if verbose:
-            self.logger = Logger()
-
-    def _init_model(self):
-        """
-        Initialize model and distributed training environment(if needed).
-        May need to provide two different initialization methods:
-            1. user-defined (from local path)
-            2. load from a checkpoint (hugging face)
-        """
-
-    def _verify_config(self):
-        """
-        Verify the configuration to avoid potential bugs.
-        """
-
-    def generate(self):
-        pass
-
-    def step(self):
-        """
-        In each step, do the follows:
-            1. Run request_handler to update the kv cache and running input_ids
-            2. Run model to generate the next token
-            3. Check whether there is finied request and decode
-        """
+from itertools import count
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import Sequence
+from colossalai.logging import get_dist_logger
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .request_handler import RequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+    "LlamaForCausalLM",
+]
+
+
+class InferenceEngine:
+
+    """
+    InferenceEngine which manages the inference process..
+
+    Args:
+        model (nn.Module): Path or nn.Module of this model.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
+        verbose (bool): Determine whether or not to log the generation process.
+        model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        inference_config: Optional["InferenceConfig"] = None,
+        verbose: bool = False,
+        model_policy: Policy = None,
+    ) -> None:
+        assert inference_config, "Please provide inference_config."
+        self.tokenizer = tokenizer
+        self.inference_config = inference_config
+        self.model_config = model.config
+
+        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
+            self.dtype = torch.float16
+            model.half()
+        else:
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+
+        if model_policy is None:
+            model_policy = model_policy_map[self.model_config.model_type]()
+
+        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
+        )
+
+        self.verbose = verbose
+        if verbose:
+            self.logger = get_dist_logger(__name__)
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.counter = count()
+
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
+            self.tokenizer, PreTrainedTokenizer
+        ):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+            )
+        assert (
+            self.model.__class__.__name__ in _supported_models
+        ), f"Model {self.model.__class__.__name__} is not supported."
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): Path or nn.Module of this model.
+            model_policy (Policy): The policy to shardformer model which is determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
+
+        Returns:
+            nn.Module: _description_
+        """
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"quant": self.inference_config.quant_mode},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()
+
+    def generate(
+        self,
+        generation_config: GenerationConfig = None,
+    ) -> List[str]:
+        """
+        Executing the inference step.
+
+        Args:
+            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            List[str]: Inference result returned by one generation.
+        """
+
+        self.generation_config = generation_config
+
+        output_list = []
+
+        while self.request_handler.check_unfinished_seqs():
+            output_list += self.step()
+
+        return output_list
+
+    def add_request(
+        self,
+        requests_id: List[int] = None,
+        prompts: List[str] = None,
+        prompts_token_ids: List[int] = None,
+    ) -> None:
+        """
+        Add requests.
+
+        Args:
+            requests_id (List[int], optional): The request ID. Defaults to None.
+            prompts (Union[List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+        """
+
+        block_size = self.inference_config.block_size
+
+        if prompts_token_ids is None:
+            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
+            prompts_token_ids = []
+            for prompt in prompts:
+                prompts_token_ids.append(self.tokenizer.encode(prompt))
+
+        prompts_num = len(prompts_token_ids)
+
+        for i in range(prompts_num):
+            if requests_id:
+                request_id = requests_id[i]
+            else:
+                request_id = next(self.counter)
+            if prompts == None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                None,
+                self.tokenizer.eos_token_id,
+                self.inference_config.max_output_len,
+            )
+            self.request_handler.add_sequence(sequence)
+
+    def step(self) -> List[str]:
+        """
+        In each step, do the follows:
+            1. Run RequestHandler.schedule() and get the batch used for inference.
+            2. Run model to generate the next token
+            3. Update waiting list and running list in RequestHandler and get finished sequences.
+            4. Decode and return finished sequences.
+
+        Returns:
+            List[str]: Decoded finished sequences generated by one step.
+        """
+
+        if self.verbose:
+            self.logger.info("Running generation step")
+
+        output_list = []
+        self.request_handler.schedule()
+
+        # Uncomment if the development of RequestHandler is completed.
+        # logits = self.model(batch)
+        # self.request_handler.search_tokens(logits, self.generation_config)
+
+        finished_sequences = self.request_handler.update()
+
+        # Decode completed sentences.
+        for seq in finished_sequences:
+            if seq.prompt:
+                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
+                output_list.append(seq.prompt + output_str)
+            else:
+                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
+                output_list.append(output_str)
+
+        return output_list
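The intended call pattern of the new InferenceEngine is easiest to see in the test added near the end of this commit; below is a condensed sketch of that flow. The tiny random LlamaConfig and the hf-internal-testing/llama-tokenizer checkpoint are the ones the test itself uses; a CUDA device is assumed because _shardformer() moves the sharded model onto the GPU, and with the RequestHandler still stubbed out, generate() simply decodes each prompt back unchanged.

import transformers
from transformers import AutoTokenizer

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Assumed single-process, single-GPU setup (the port is arbitrary).
colossalai.launch(config={}, rank=0, world_size=1, port=29500, host="localhost")

# A deliberately tiny random Llama model, exactly as in the new test file.
model = transformers.LlamaForCausalLM(
    transformers.LlamaConfig(
        vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
    )
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

engine = InferenceEngine(model, tokenizer, InferenceConfig(), verbose=True)
engine.add_request(prompts=["介绍一下北京", "介绍一下武汉"])

# With the model forward still commented out in step(), each finished sequence
# decodes back to its own prompt, so outputs equals the prompt list.
outputs = engine.generate(None)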
colossalai/inference/core/request_handler.py
@@ -1,5 +1,7 @@
 from typing import List
 
+from colossalai.inference.struct import BatchInfo, Sequence
+
 
 class RequestHandler:
     """
@@ -7,14 +9,17 @@ class RequestHandler:
     During generation process, we call schedule function each iteration to update current batch.
 
     Args:
-        cache_config: Configuration for initialize and manage kv cache.
+        inference_config: Store the configuration information related to inference.
+        model_config: The huggingface model config.
     """
 
-    def __init__(self, cache_config) -> None:
-        self.cache_config = cache_config
+    def __init__(self, inference_config, model_config) -> None:
+        self.inference_config = inference_config
+        self.model_config = model_config
         self._init_cache()
-        self.waiting_list: List["Reqseq"] = []
-        self.running_list: List["Reqseq"] = []
+        self.waiting_list: List["Sequence"] = []
+        self.running_list: List["Sequence"] = []
+        self.batch = BatchInfo.init_batch()
 
     def _init_cache(self):
         """
@@ -25,12 +30,17 @@ class RequestHandler:
         """
         The main logic of request handler.
         """
+        # The code below is only used for testing engine and will be modified.
+        if self.waiting_list:
+            self.running_list = self.waiting_list
+            self.batch.add_seqs(self.running_list)
+        return self.batch
 
-    def add_sequence(self, reqseq: "Reqseq"):
+    def add_sequence(self, req_seq: "Sequence"):
         """
         Add the request to waiting list.
         """
-        self.waiting_list.append(reqseq)
+        self.waiting_list.append(req_seq)
 
     def abort_sequence(self, seq_id: str):
         """
@@ -39,10 +49,23 @@ class RequestHandler:
         self._find_sequence(seq_id)
         return
 
-    def _find_sequence(self, seq_id: str) -> "Reqseq":
+    def _find_sequence(self, seq_id: str) -> "Sequence":
         """
         Find the request by seq_id.
         """
 
     def check_unfinished_seqs(self) -> bool:
-        return self.waiting_list or self.running_list
+        return len(self.waiting_list) != 0 or len(self.running_list) != 0
+
+    def update(self):
+        """
+        Update the waiting list and running list.
+        """
+
+        # The code below is only used for testing engine and will be modified.
+        self.waiting_list = []
+        self.running_list = []
+        finished_sequences = list(self.batch.sequences_set)
+
+        self.batch.clear_batch()
+        return finished_sequences
colossalai/inference/kv_cache/kvcache_manager.py
@@ -135,7 +135,7 @@ class KVCacheManager:
         and updates the provided block table with the allocated block ids.
 
         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
             context_len: The length of the processing sequnece.
         """
         assert block_table.dim() == 1
@@ -185,7 +185,7 @@ class KVCacheManager:
         and updates the provided block table if a new cache block is needed.
 
         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
             context_len: The length of the processing sequnece (already-allocated length).
         """
         assert block_table.dim() == 1
@@ -199,7 +199,7 @@ class KVCacheManager:
         and updates the provided block table with the allocated block.
 
         Args:
-            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id.
             block_local_idx: The index of the block in the block table.
             space_asked: i.e. The number of tokens to be assigned space for.
         Returns:
colossalai/inference/modeling/policy/__init__.py (new file)
@@ -0,0 +1,7 @@
+from .llama import LlamaModelInferPolicy
+
+model_policy_map = {
+    "llama": LlamaModelInferPolicy,
+}
+
+__all__ = ["LlamaModelInferPolicy", "model_polic_map"]

colossalai/inference/modeling/policy/llama.py (new file)
@@ -0,0 +1,7 @@
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+    # The code here just for test and will be modified later.
+    def __init__(self) -> None:
+        super().__init__()
colossalai/inference/struct.py
@@ -1,68 +1,82 @@
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Set
+from typing import List, Union
+
+import torch
+from ordered_set import OrderedSet
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)
 
 """
 The abstraction of request and sequence are defined here.
 """
 
 
-class RequsetStatus(enum.Enum):
-    """The status of Sentences"""
+class RequestStatus(enum.Enum):
+    """
+    The status of Sentences
+    """
 
+    # running status
     WAITING = enum.auto()
-    RUNNING = enum.auto()
+    PREFILL = enum.auto()
+    TOKEN = enum.auto()
     ABORTED = enum.auto()
+
+    # completion status
     OVERLENGTH = enum.auto()
     COMPLETED = enum.auto()
     LENGTH_CAPPED = enum.auto()
 
     @staticmethod
-    def is_finished(status: "RequsetStatus") -> bool:
+    def is_finished(status: "RequestStatus") -> bool:
         return status in [
-            RequsetStatus.OVERLENGTH,
-            RequsetStatus.COMPLETED,
-            RequsetStatus.LENGTH_CAPPED,
+            RequestStatus.OVERLENGTH,
+            RequestStatus.COMPLETED,
+            RequestStatus.LENGTH_CAPPED,
         ]
 
     @staticmethod
-    def is_running(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.RUNNING
+    def is_running(status: "RequestStatus") -> bool:
+        return status in [
+            RequestStatus.PREFILL,
+            RequestStatus.TOKEN,
+        ]
 
     @staticmethod
-    def is_waiting(status: "RequsetStatus") -> bool:
-        return status == RequsetStatus.WAITING
+    def is_waiting(status: "RequestStatus") -> bool:
+        return status == RequestStatus.WAITING
 
 
+@dataclass
 class Sequence:
     """Store information of input sequence.
 
     Args:
-        request_id: The ID of input sequence.
-        prompt: The prompt of input sequence.
-        token_id: The tokens ID of input sequence.
-        block_size: The block size of input sequence.
-        sample_params: The sample_params of input sequence.
-        block_table_index: The index of input sequence in block_table.
+        request_id (int): The ID of input sequence.
+        prompt (str): The prompt of input sequence.
+        input_token_id (List[int]): The tokens ID of input sequence.
+        block_size (int): The block size of input sequence.
+        sample_params (SampleParams): The sample_params of input sequence.
+        block_table (torch.Tensor): The index of input sequence in block_table.
+        eos_token_id (int): The eos token id for this inference process.
+        max_output_len (int): Maximum output length.
     """
 
-    def __init__(
-        self,
-        request_id: int,
-        prompt: str,
-        token_id: List[int],
-        block_size: int,
-        sample_params,  # SampleParams needs to be imported later.
-        block_table_index: int,
-    ):
-        self.request_id = request_id
-        self.prompt = prompt
-        self.input_token_id = token_id
-        self.blokc_size = block_size
-        self.sample_params = sample_params
-        self.output_token_id = []
-        self.status = RequsetStatus.WAITING
-        self.block_table_index = block_table_index
+    request_id: int
+    prompt: str
+    input_token_id: List[int]
+    block_size: int
+    sample_params: any  # SampleParams needs to be imported later.
+    block_table: torch.Tensor
+    eos_token_id: int
+    max_output_len: int = 256
+
+    def __post_init__(self):
+        self.output_token_id = []
+        self.status = RequestStatus.WAITING
 
     def get_sentence_len(self) -> None:
         """
@@ -84,17 +98,30 @@ class Sequence:
 
     def check_finish(self) -> bool:
         """
-        Check whether inference is over.
+        Check whether the inference is finished.
+
+        Returns:
+            bool: Whether the inference is finished.
         """
-        return RequsetStatus.is_finished(self.status)
+        if RequestStatus.is_finished(self.status):
+            return True
+
+        if self.output_token_id:
+            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+                self.status = RequestStatus.COMPLETED
+                return True
+
+        return False
+
+    def __hash__(self):
+        return hash(self.request_id)
 
     def __repr__(self) -> str:
         return (
             f"Request ID(request_id={self.request_id}, "
             f"prompt={self.prompt}, "
             f"status={self.status.name}, "
-            f"sample_params={self.sample_params}, "
-            f"logical block number={len(self._logical_blocks)}"
+            f"sample_params={self.sample_params}"
         )
 
 
@@ -104,34 +131,38 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
-    sequences_set: Set[Sequence]
-    block_table: Dict[int, int] = None
+    sequences_set: OrderedSet["Sequence"]
 
     @classmethod
-    def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
+    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
         """
         Initializes inference batches by input sentence list.
 
         Args:
-            seqs (List[Sequence]): List of input sequence.
+            seqs (List["Sequence"]): List of input sequence.
         """
-        sequences_set = set()
-        block_table = {}
-        for seq in seqs:
-            if seq in sequences_set:
-                assert (
-                    seq.request_id in block_table.keys()
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
-                continue
-
-            assert (
-                seq.request_id not in block_table.keys()
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
-
-            sequences_set.add(seq)
-            block_table[seq.request_id] = seq.block_table_index
-
-        return cls(sequences_set=sequences_set, block_table=block_table)
+        sequences_set = OrderedSet()
+
+        if seqs is not None:
+            if not isinstance(seqs, list):
+                seqs = [seqs]
+            for seq in seqs:
+                if seq in sequences_set:
+                    logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
+                    continue
+
+                sequences_set.add(seq)
+
+        return cls(sequences_set=sequences_set)
+
+    def get_block_table_tensor(self):
+        tesnor_list = []
+        for seq in self.sequences_set:
+            block_table = seq.block_table
+            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+            tesnor_list.append(seq.block_table)
+        return torch.concat(tesnor_list)
 
     def clear_batch(self) -> None:
         """
@@ -139,35 +170,76 @@
         """
         for seq in self.sequences_set:
             if not seq.check_finish():
-                seq.status = RequsetStatus.ABORTED
+                seq.status = RequestStatus.ABORTED
         self.sequences_set.clear()
-        self.block_table.clear()
 
-    def fliter_batch(self) -> None:
+    def fliter_batch(self) -> List["Sequence"]:
         """
         Remove completed sentences from a batch.
-        """
-        for seq in self.sequences_set.copy():
-            if seq.check_finish():
-                self.sequences_set.remove(seq)
-                del self.block_table[seq.request_id]
+
+        Returns:
+            List["Sequence"]: List of finished sequences.
+        """
+        finish_seqs = []
+        for seq in self.sequences_set:
+            if seq.check_finish():
+                finish_seqs.append(seq)
+        for finish_seq in finish_seqs:
+            self.sequences_set.discard(finish_seq)
+        return finish_seqs
+
+    def abort_seq(self, seq: "Sequence") -> "Sequence":
+        """
+        Remove sequence from the batch.
+        """
+        if not seq.check_finish():
+            seq.status = RequestStatus.ABORTED
+        self.sequences_set.discard(seq)
+        return seq
 
-    def add_seqs(self, seqs: List[Sequence]) -> None:
+    def add_seqs(self, seqs: List["Sequence"]) -> None:
         """
         Add new sequence to batch
 
         Args:
-            seqs (List[Sequence]): The list of new sequences.
+            seqs (List["Sequence"]): The list of new sequences.
         """
+
+        if not isinstance(seqs, list):
+            seqs = [seqs]
+
         for seq in seqs:
             if seq in self.sequences_set:
-                print("The sequence is already in sequences_set.")
-                assert (
-                    seq.request_id in self.block_table
-                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
+                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                 continue
-            assert (
-                seq.request_id not in self.block_table
-            ), "The sequence has not been added to sequences_set, but it is already in block_table."
             self.sequences_set.add(seq)
-            self.block_table[seq.request_id] = seq.block_table_index
+
+    def is_empty(self) -> None:
+        """
+        Check whether sequences_set is empty.
+        """
+        return not self.sequences_set
+
+    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+        """
+        Add an output token for each sentence in the batch.
+
+        Args:
+            tokens (List[int]): A batch of tokens
+        """
+
+        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
+
+        for seq, token in zip(self.sequences_set, tokens):
+            if not isinstance(token, list):
+                if not isinstance(token, int):
+                    raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
+                token = [token]
+            seq.output_token_id += token
+            seq.check_finish()
+
+    def get_batch_size(self) -> int:
+        """
+        Get batch_size of this batch
+        """
+        return len(self.sequences_set)
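A short sketch of the new Sequence/BatchInfo API, mirroring the updated config-and-struct test further down. The field values are the test's own; block_table stays None because nothing here touches the KV cache manager yet, and feeding the eos token id through update_batch_tokens() is enough to drive a sequence to COMPLETED.

from colossalai.inference.struct import BatchInfo, Sequence

seq = Sequence(
    request_id=1,
    prompt="abc",
    input_token_id=[1, 2, 3],
    block_size=16,
    sample_params=None,
    block_table=None,
    eos_token_id=2,
    max_output_len=256,
)

batch = BatchInfo.init_batch([seq])
batch.update_batch_tokens([2])   # 2 is the eos_token_id, so check_finish() marks the sequence COMPLETED
finished = batch.fliter_batch()  # returns [seq]; spelling follows the method name in the diff
assert batch.is_empty()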
requirements/requirements-infer.txt
@@ -1,4 +1,5 @@
+ordered_set
 transformers==4.34.0
 auto-gptq==0.5.0
 git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9

requirements/requirements-test.txt
@@ -1,4 +1,6 @@
 diffusers
+fbgemm-gpu==0.2.0
+ordered_set
 pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon
tests/test_infer/test_config_and_struct.py
@@ -1,26 +1,45 @@
+import pytest
+
+import colossalai
 from colossalai.inference.config import InferenceConfig
-from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
+from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.testing import spawn
 
 
-def test_config_and_inferenceData():
-    config = InferenceConfig("/llama")
-    assert config.max_batch_size
+def check_config_and_inference():
+    config = InferenceConfig()
+    assert config.max_batch_size == 8
     sequence = Sequence(
         request_id=1,
         prompt="abc",
-        token_id=[1, 2, 3],
+        input_token_id=[1, 2, 3],
         block_size=16,
         sample_params=None,
-        block_table_index=1,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
     )
 
     sequence2 = Sequence(
         request_id=2,
         prompt="bcd",
-        token_id=[4, 5, 6],
+        input_token_id=[4, 5, 6],
         block_size=16,
         sample_params=None,
-        block_table_index=2,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
+    )
+
+    sequence3 = Sequence(
+        request_id=3,
+        prompt="efg",
+        input_token_id=[7, 8, 9],
+        block_size=16,
+        sample_params=None,
+        block_table=None,
+        eos_token_id=2,
+        max_output_len=256,
     )
 
     assert sequence.get_sentence_len() == 3
@@ -29,15 +48,34 @@ def test_config_and_inferenceData():
     assert sequence.check_finish() == False
 
     batch = BatchInfo.init_batch([sequence])
-    assert batch.block_table[sequence.request_id] == sequence.block_table_index
-    sequence.status = RequsetStatus.COMPLETED
-    batch.fliter_batch()
-    assert batch.block_table == {}
-    batch.add_seqs([sequence2])
-    assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
+    batch.add_seqs([sequence2, sequence3])
+    batch.add_seqs([sequence])
+
+    assert batch.is_empty() == False
+    assert batch.get_batch_size() == 3
+    batch.update_batch_tokens([1, 2, 3])
+    seq = batch.abort_seq(sequence)
+    seq2 = batch.fliter_batch()[0]
+
+    assert batch.get_batch_size() == 1
+    assert seq.get_output_len() == 1
+    assert seq.output_token_id == [1]
+    assert seq2.get_output_len() == 1
+    assert seq2.output_token_id == [2]
+
     batch.clear_batch()
-    assert batch.block_table == {}
+    assert batch.is_empty() == True
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_config_and_inference()
+
+
+@pytest.mark.dist
+def test_config_and_inference():
+    spawn(run_dist, 1)
 
 
 if __name__ == "__main__":
-    test_config_and_inferenceData()
+    test_config_and_inference()
tests/test_infer/test_inference_engine.py (new file)
@@ -0,0 +1,44 @@
+import pytest
+import transformers
+from transformers import AutoTokenizer
+
+import colossalai
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import spawn
+
+
+def check_inference_engine():
+    model = transformers.LlamaForCausalLM(
+        transformers.LlamaConfig(
+            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+        )
+    )
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    inference_config = InferenceConfig()
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+
+    inputs = [
+        "介绍一下北京",
+        "介绍一下武汉",
+    ]
+
+    inference_engine.add_request(prompts=inputs)
+    outputs = inference_engine.generate(None)
+
+    for s1, s2 in zip(inputs, outputs):
+        assert s1 == s2
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_inference_engine()
+
+
+@pytest.mark.dist
+def test_inference_engine():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_inference_engine()
tests/test_infer/test_kvcache_manager.py
@@ -1,12 +1,14 @@
 import random
 
+import pytest
 import torch
 from transformers.models.llama import LlamaConfig
 
+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, spawn
 
 
 @parameterize(
@@ -64,7 +66,7 @@ def test_logical_blocks(test_config):
         },
     ],
 )
-def test_cache_manager(test_config):
+def check_cache_manager(test_config):
     disable_existing_loggers()
 
     assert test_config["max_batch_size"] > 1
@@ -78,7 +80,7 @@ def test_cache_manager(test_config):
     max_input_length = test_config["max_input_len"]
     max_output_length = test_config["max_output_len"]
 
-    inference_config = InferenceConfig(model="", **test_config)
+    inference_config = InferenceConfig(**test_config)
     model_config = LlamaConfig(
         hidden_size=hidden_size,
         num_hidden_layers=num_layers,
@@ -147,6 +149,16 @@
     assert cache_manager.get_num_available_blocks() == num_blocks
 
 
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_cache_manager()
+
+
+@pytest.mark.dist
+def test_cache_manager():
+    spawn(run_dist, 1)
+
+
 if __name__ == "__main__":
     test_logical_blocks()
     test_cache_manager()