mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Add BatchInferState, Sequence and InferConfig (#5149)
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer
* update codes
* add ci test for config and struct
parent: 2bb92243d4
commit: fab9b931d9
@@ -1,7 +0,0 @@
-"""
-Our config consists of three parts:
-1. model_config: The configuration for the model, including `model name`, `model path` and self-defined layer.
-2. parallel_config: The configuration for parallelize model, including `tp_size`, `pp_size`, `world size`, `local rank`, `master port`, `master ip`.
-3. cache_config: Configuration for initialize and manage kv cache, including `block size`, `block num`
-For the convenience of users, we provide a unified config api for that wrapped all the configs. One can easily construct a colossal_config by setting the needed configs.
-"""
@@ -0,0 +1,54 @@
+from typing import Optional, Union
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+
+@dataclass
+class InferenceConfig:
+    """The inference configuration.
+
+    Args:
+        model: Path or nn.Module of this model.
+        tokenizer: Path of the tokenizer to use.
+        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Whether to trust remote code from huggingface.
+        max_batch_size: Maximum batch size.
+        max_output_len: Maximum output length.
+        max_input_len: Maximum input length.
+        block_size: The number of token slots in a logical block.
+        gpu_utilization_rate: Maximum GPU memory usage ratio.
+        dtype: The data type for weights and activations.
+        tp_size: Tensor parallel size.
+        pp_size: Pipeline parallel size.
+        max_seq_len: Maximum length of input sentence.
+        quant_mode: Quantization mode.
+        revision: The specific version of the model to use (a branch name, tag name, or commit id).
+    """
+
+    model: Union[str, nn.Module]
+    tokenizer: Optional[str] = None
+    tokenizer_mode: str = "auto"
+    trust_remote_code: bool = False
+    max_batch_size: int = 8
+    max_output_len: int = 256
+    max_input_len: int = 256
+    block_size: int = 16
+    gpu_utilization_rate: float = 0.7
+    dtype: Union[str, torch.dtype] = torch.float32
+    tp_size: int = 1
+    pp_size: int = 1
+    max_seq_len: Optional[int] = None
+    quant_mode: Optional[str] = None
+    revision: Optional[str] = None
+
+    def __post_init__(self):
+        self._verify_args()
+
+    def _verify_args(self):
+        if self.gpu_utilization_rate > 1.0:
+            raise ValueError(
+                f"GPU utilization rate should be no greater than 1.0, but is set to {self.gpu_utilization_rate}."
+            )
+        if self.tokenizer_mode not in ["auto", "slow"]:
+            raise ValueError(f"Tokenizer mode must be either 'auto' or 'slow', but got {self.tokenizer_mode}.")
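As a quick illustration (not part of the commit), here is a minimal usage sketch of the dataclass added above; the "/llama" path is a placeholder and the module path is taken from the CI test further down:

from colossalai.inference.core.config import InferenceConfig

config = InferenceConfig(model="/llama", max_batch_size=4)
assert config.block_size == 16            # unspecified fields keep their defaults
# __post_init__ calls _verify_args, so an invalid value fails fast, e.g.:
# InferenceConfig(model="/llama", gpu_utilization_rate=1.2)  raises ValueError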
@@ -1,12 +1,14 @@
 from logging import Logger
 from typing import Optional
 
-from .request_handler import RequestHandler
+from transformers import AutoConfig
+
+from .config import InferenceConfig
 
 
-class InferEngine:
+class InferenceEngine:
     """
-    InferEngine is the core component for Inference.
+    InferenceEngine is the core component for Inference.
 
     It is responsible for launch the inference process, including:
     - Initialize model and distributed training environment(if needed)
@@ -15,37 +17,27 @@ class InferEngine:
     - Log the generation process
 
     Args:
-        colossal_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
-        model_config : The configuration for the model.
-        parallel_config: The configuration for parallelize model.
-        cache_config : Configuration for initialize and manage kv cache.
-        tokenizer (Tokenizer): The tokenizer to be used for inference.
-        use_logger (bool): Determine whether or not to log the generation process.
+        tokenizer: Path of the tokenizer to use.
+        inference_config: A unified config API that wraps all the configs. You can use it to replace the configs below.
+        verbose (bool): Determine whether or not to log the generation process.
     """
 
     def __init__(
         self,
-        model_config,
-        cache_config,
-        parallel_config,
-        tokenizer,
-        use_logger: bool = False,
-        colossal_config: Optional["ColossalInferConfig"] = None,
+        tokenizer: Optional[str] = None,
+        inference_config: Optional["InferenceConfig"] = None,
+        verbose: bool = False,
     ) -> None:
-        assert colossal_config or (
-            model_config and cache_config and parallel_config
-        ), "Please provide colossal_config or model_config, cache_config, parallel_config"
-        if colossal_config:
-            model_config, cache_config, parallel_config = colossal_config
-
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.parallel_config = parallel_config
-        self._verify_config()
-
+        assert inference_config, "Please provide inference_config."
+
         self._init_model()
-        self.request_handler = RequestHandler(cache_config)
-        if use_logger:
+        # cache_config may need to be modified later.
+        # self.request_handler = RequestHandler(cache_config)
+        self.tokenizer = tokenizer
+        self.hf_model_config = AutoConfig.from_pretrained(
+            inference_config.model, trust_remote_code=inference_config.trust_remote_code, revision=inference_config.revision
+        )
+        if verbose:
             self.logger = Logger()
 
     def _init_model(self):
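The constructor above now derives the Hugging Face model config from the unified InferenceConfig. A small sketch of that relationship (not part of the commit; it assumes model is given as a checkpoint path string, and "/llama" is a placeholder):

from transformers import AutoConfig

from colossalai.inference.core.config import InferenceConfig

inference_config = InferenceConfig(model="/llama")
hf_model_config = AutoConfig.from_pretrained(
    inference_config.model,
    trust_remote_code=inference_config.trust_remote_code,
    revision=inference_config.revision,
)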
@@ -0,0 +1,169 @@
+import enum
+from dataclasses import dataclass
+from typing import Dict, List, Set
+
+
+class RequsetStatus(enum.Enum):
+    """The status of a sequence."""
+
+    WAITING = enum.auto()
+    RUNNING = enum.auto()
+    ABORTED = enum.auto()
+    OVERLENGTH = enum.auto()
+    COMPLETED = enum.auto()
+    LENGTH_CAPPED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "RequsetStatus") -> bool:
+        return status in [
+            RequsetStatus.OVERLENGTH,
+            RequsetStatus.COMPLETED,
+            RequsetStatus.LENGTH_CAPPED,
+        ]
+
+    @staticmethod
+    def is_running(status: "RequsetStatus") -> bool:
+        return status == RequsetStatus.RUNNING
+
+    @staticmethod
+    def is_waiting(status: "RequsetStatus") -> bool:
+        return status == RequsetStatus.WAITING
+
+
+class Sequence:
+    """Store information of an input sequence.
+
+    Args:
+        request_id: The ID of the input sequence.
+        prompt: The prompt of the input sequence.
+        token_id: The token IDs of the input sequence.
+        block_size: The block size of the input sequence.
+        sample_params: The sample_params of the input sequence.
+        block_table_index: The index of the input sequence in the block_table.
+    """
+
+    def __init__(
+        self,
+        request_id: int,
+        prompt: str,
+        token_id: List[int],
+        block_size: int,
+        sample_params,  # SampleParams needs to be imported later.
+        block_table_index: int,
+    ):
+        self.request_id = request_id
+        self.prompt = prompt
+        self.input_token_id = token_id
+        self.block_size = block_size
+        self.sample_params = sample_params
+        self.output_token_id = []
+        self.status = RequsetStatus.WAITING
+        self.block_table_index = block_table_index
+
+    def get_sentence_len(self) -> int:
+        """
+        Get the total length (input + output) of the current sentence.
+        """
+        return len(self.input_token_id) + len(self.output_token_id)
+
+    def get_input_len(self) -> int:
+        """
+        Get the length of the input sentence.
+        """
+        return len(self.input_token_id)
+
+    def get_output_len(self) -> int:
+        """
+        Get the output length of the current sentence.
+        """
+        return len(self.output_token_id)
+
+    def check_finish(self) -> bool:
+        """
+        Check whether inference is over.
+        """
+        return RequsetStatus.is_finished(self.status)
+
+    def __repr__(self) -> str:
+        return (
+            f"Request ID(request_id={self.request_id}, "
+            f"prompt={self.prompt}, "
+            f"status={self.status.name}, "
+            f"sample_params={self.sample_params}, "
+            # NOTE: _logical_blocks is not defined in __init__ in this commit.
+            f"logical block number={len(self._logical_blocks)}"
+        )
+
+
+@dataclass
+class BatchHandler:
+    """
+    Information to be passed and used for a batch of sequences.
+    """
+
+    sequences_set: Set[Sequence]
+    block_table: Dict[int, int]
+
+    @classmethod
+    def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
+        """
+        Initialize an inference batch from a list of input sequences.
+
+        Args:
+            seqs (List[Sequence]): List of input sequences.
+        """
+        sequences_set = set()
+        block_table = {}
+        for seq in seqs:
+            if seq in sequences_set:
+                print("The sequence is already in sequences_set.")
+                assert (
+                    seq.request_id in block_table
+                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
+                continue
+            assert (
+                seq.request_id not in block_table
+            ), "The sequence has not been added to sequences_set, but it is already in block_table."
+
+            sequences_set.add(seq)
+            block_table[seq.request_id] = seq.block_table_index
+
+        return cls(sequences_set=sequences_set, block_table=block_table)
+
+    def clear_batch(self) -> None:
+        """
+        Clear the sequence set and block table.
+        """
+        for seq in self.sequences_set:
+            if not seq.check_finish():
+                seq.status = RequsetStatus.ABORTED
+        self.sequences_set.clear()
+        self.block_table.clear()
+
+    def fliter_batch(self) -> None:
+        """
+        Remove completed sentences from the batch.
+        """
+        # Iterate over a snapshot so removing from the set while looping is safe.
+        for seq in list(self.sequences_set):
+            if seq.check_finish():
+                self.sequences_set.remove(seq)
+                del self.block_table[seq.request_id]
+
+    def add_seqs(self, seqs: List[Sequence]) -> None:
+        """
+        Add new sequences to the batch.
+
+        Args:
+            seqs (List[Sequence]): The list of new sequences.
+        """
+        for seq in seqs:
+            if seq in self.sequences_set:
+                print("The sequence is already in sequences_set.")
+                assert (
+                    seq.request_id in self.block_table
+                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
+                continue
+            assert (
+                seq.request_id not in self.block_table
+            ), "The sequence has not been added to sequences_set, but it is already in block_table."
+            self.sequences_set.add(seq)
+            self.block_table[seq.request_id] = seq.block_table_index
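A short lifecycle sketch (not part of the commit) showing how the pieces above interact; the status transitions are illustrative, since the scheduler that would drive them is not included in this change:

from colossalai.inference.core.inference_struct import BatchHandler, RequsetStatus, Sequence

seq = Sequence(request_id=1, prompt="hello", token_id=[1, 2, 3],
               block_size=16, sample_params=None, block_table_index=0)
batch = BatchHandler.init_batch([seq])          # block_table maps request_id -> block_table_index

seq.status = RequsetStatus.RUNNING              # a scheduler would mark the sequence as running
seq.output_token_id.append(42)                  # one generated token
seq.status = RequsetStatus.COMPLETED            # generation finished
assert seq.check_finish() and seq.get_output_len() == 1

batch.fliter_batch()                            # drops the finished sequence and its block-table entry
assert seq.request_id not in batch.block_table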
@@ -0,0 +1,37 @@
+from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.core.inference_struct import BatchHandler, Sequence
+
+
+def test_config_and_struct():
+    InferenceConfig("/llama")
+    sequence = Sequence(
+        request_id=1,
+        prompt="abc",
+        token_id=[1, 2, 3],
+        block_size=16,
+        sample_params=None,
+        block_table_index=1,
+    )
+
+    sequence2 = Sequence(
+        request_id=2,
+        prompt="bcd",
+        token_id=[4, 5, 6],
+        block_size=16,
+        sample_params=None,
+        block_table_index=2,
+    )
+
+    assert sequence.get_sentence_len() == 3
+    assert sequence.get_input_len() == 3
+    assert sequence.get_output_len() == 0
+    assert not sequence.check_finish()
+
+    batch = BatchHandler.init_batch([sequence])
+    batch.fliter_batch()
+    batch.add_seqs([sequence2])
+    batch.clear_batch()
+
+
+if __name__ == "__main__":
+    test_config_and_struct()
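One behaviour the CI test above does not exercise is the duplicate guard in init_batch and add_seqs; a small sketch of it (not part of the commit, reusing the classes defined above):

from colossalai.inference.core.inference_struct import BatchHandler, Sequence

seq = Sequence(request_id=3, prompt="xyz", token_id=[7], block_size=16,
               sample_params=None, block_table_index=3)
batch = BatchHandler.init_batch([seq, seq])     # duplicate in the input list
assert len(batch.sequences_set) == 1            # the sequence is stored only once
assert batch.block_table == {3: 3}              # request_id -> block_table_index

batch.add_seqs([seq])                           # adding it again only prints a message
assert len(batch.sequences_set) == 1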