From fab9b931d9e24c6e8ada8025cf8cf12719c3d2af Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 7 Dec 2023 14:34:01 +0800 Subject: [PATCH] [Inference]Add BatchInferState, Sequence and InferConfig (#5149) * add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct --- colossalai/inference/config.py | 7 - colossalai/inference/core/config.py | 54 ++++++ colossalai/inference/core/engine.py | 46 ++--- colossalai/inference/core/inference_struct.py | 169 ++++++++++++++++++ tests/test_infer/test_config_and_struct.py | 37 ++++ 5 files changed, 279 insertions(+), 34 deletions(-) delete mode 100644 colossalai/inference/config.py create mode 100644 colossalai/inference/core/config.py create mode 100644 colossalai/inference/core/inference_struct.py create mode 100644 tests/test_infer/test_config_and_struct.py diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py deleted file mode 100644 index d274beb14..000000000 --- a/colossalai/inference/config.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Our config consists of three parts: - 1. model_config: The configuration for the model, including `model name`, 'model path' and self-defined layer. - 2. parallel_config: The configuration for parallelize model, including `tp_size`,'pp_size', `world size`, `local rank`, `master port`, `master ip`. - 3. cache_config: Configuration for initialize and manage kv cache, including `block size`, `block num` -For the convenience of users, we provide a unified config api for that wrapped all the configs. One can easily construct a colossal_config by setting the needed configs. -""" diff --git a/colossalai/inference/core/config.py b/colossalai/inference/core/config.py new file mode 100644 index 000000000..6b44dd7af --- /dev/null +++ b/colossalai/inference/core/config.py @@ -0,0 +1,54 @@ +from typing import Optional, Union +from dataclasses import dataclass + +import torch +import torch.nn as nn + +@dataclass +class InferenceConfig: + """The inference configuration. + + Args: + model: Path or nn.Module of this model. + tokenizer: Path of the tokenizer to use. + tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Whether to trust remote code from huggingface. + max_batch_size: Maximum batch size. + max_output_len: Maximum output length. + max_input_len: Maximum input length. + block_size: The number of blocks in a logical block. + gpu_utilization_rate: Maximum GPU memory usage ratio. + dtype: The data type for weights and activations. + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + max_seq_len: Maximum length of input sentence. + quant_mode: Quantization mode. + revision: The specific version(a branch, name, a commit id, or a tag name) of model to use. + """ + + model: Union[str, nn.Module] + tokenizer: str = None + tokenizer_mode: str = "auto" + trust_remote_code: bool = False + max_batch_size: int = 8 + max_output_len: int = 256 + max_input_len: int = 256 + block_size: int = 16 + gpu_utilization_rate: float = 0.7 + dtype: Union[str, torch.dtype] = torch.float32 + tp_size: int = 1 + pp_size: int = 1 + max_seq_len: Optional[int] = None + quant_mode: Optional[str] = None + revision: Optional[str] = None + + def __post_init__(self): + self._verify_args() + + def _verify_args(self): + if self.gpu_utilization_rate > 1.0: + raise ValueError( + f"GPU utilization should be less than 1.0, but is set to {self.gpu_memory_utilization}." + ) + if self.tokenizer_mode not in ["auto", "slow"]: + raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}") diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index bf26b3ecb..7f78e9761 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,12 +1,14 @@ from logging import Logger from typing import Optional -from .request_handler import RequestHandler +from transformers import AutoConfig +from .config import InferenceConfig -class InferEngine: + +class InferenceEngine: """ - InferEngine is the core component for Inference. + InferenceEngine is the core component for Inference. It is responsible for launch the inference process, including: - Initialize model and distributed training environment(if needed) @@ -15,37 +17,27 @@ class InferEngine: - Log the generation process Args: - colossal_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs. - model_config : The configuration for the model. - parallel_config: The configuration for parallelize model. - cache_config : Configuration for initialize and manage kv cache. - tokenizer (Tokenizer): The tokenizer to be used for inference. - use_logger (bool): Determine whether or not to log the generation process. + tokenizer: Path of the tokenizer to use. + inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs. + verbose (bool): Determine whether or not to log the generation process. """ def __init__( self, - model_config, - cache_config, - parallel_config, - tokenizer, - use_logger: bool = False, - colossal_config: Optional["ColossalInferConfig"] = None, + tokenizer: str = None, + inference_config: Optional["InferenceConfig"] = None, + verbose: bool = False, ) -> None: - assert colossal_config or ( - model_config and cache_config and parallel_config - ), "Please provide colossal_config or model_config, cache_config, parallel_config" - if colossal_config: - model_config, cache_config, parallel_config = colossal_config - - self.model_config = model_config - self.cache_config = cache_config - self.parallel_config = parallel_config - self._verify_config() + assert inference_config, "Please provide inference_config." self._init_model() - self.request_handler = RequestHandler(cache_config) - if use_logger: + # cache_config may need to be modified later. + # self.request_handler = RequestHandler(cache_config) + self.tokenizer = tokenizer + self.hf_model_config = AutoConfig.from_pretrained( + self.model, trust_remote_code=self.trust_remote_code, revision=self.revision + ) + if verbose: self.logger = Logger() def _init_model(self): diff --git a/colossalai/inference/core/inference_struct.py b/colossalai/inference/core/inference_struct.py new file mode 100644 index 000000000..331f0308a --- /dev/null +++ b/colossalai/inference/core/inference_struct.py @@ -0,0 +1,169 @@ +import enum +from dataclasses import dataclass +from typing import Dict, List, Set + + +class RequsetStatus(enum.Enum): + """The status of Sentences""" + + WAITING = enum.auto() + RUNNING = enum.auto() + ABORTED = enum.auto() + OVERLENGTH = enum.auto() + COMPLETED = enum.auto() + LENGTH_CAPPED = enum.auto() + + @staticmethod + def is_finished(status: "RequsetStatus") -> bool: + return status in [ + RequsetStatus.OVERLENGTH, + RequsetStatus.COMPLETED, + RequsetStatus.LENGTH_CAPPED, + ] + + @staticmethod + def is_running(status: "RequsetStatus") -> bool: + return status == RequsetStatus.RUNNING + + @staticmethod + def is_waiting(status: "RequsetStatus") -> bool: + return status == RequsetStatus.WAITING + + +class Sequence: + """Store information of input sequence. + + Args: + request_id: The ID of input sequence. + prompt: The prompt of input sequence. + token_id: The tokens ID of input sequence. + block_size: The block size of input sequence. + sample_params: The sample_params of input sequence. + block_table_index: The index of input sequence in block_table. + """ + + def __init__( + self, + request_id: int, + prompt: str, + token_id: List[int], + block_size: int, + sample_params, # SampleParams needs to be imported later. + block_table_index: int, + ): + self.request_id = request_id + self.prompt = prompt + self.input_token_id = token_id + self.blokc_size = block_size + self.sample_params = sample_params + self.output_token_id = [] + self.status = RequsetStatus.WAITING + self.block_table_index = block_table_index + + def get_sentence_len(self) -> None: + """ + Get length of current sentence. + """ + return len(self.input_token_id) + len(self.output_token_id) + + def get_input_len(self) -> None: + """ + Get length of input sentence. + """ + return len(self.input_token_id) + + def get_output_len(self) -> None: + """ + Get output length of current sentence. + """ + return len(self.output_token_id) + + def check_finish(self) -> bool: + """ + Check whether inference is over. + """ + return RequsetStatus.is_finished(self.status) + + def __repr__(self) -> str: + return ( + f"Request ID(request_id={self.request_id}, " + f"prompt={self.prompt}, " + f"status={self.status.name}, " + f"sample_params={self.sample_params}, " + f"logical block number={len(self._logical_blocks)}" + ) + + +@dataclass +class BatchHandler: + """ + Information to be passed and used for a batch of sequences. + """ + + sequences_set: Set[Sequence] + block_table: Dict[int, int] + + @classmethod + def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler": + """ + Initializes inference batches by input sentence list. + + Args: + seqs (List[Sequence]): List of input sequence. + """ + sequences_set = set() + block_table = {} + for seq in seqs: + if seq in sequences_set: + print("The sequence is already in sequences_set.") + assert ( + seq.request_id in block_table + ), "The sequence has been added to sequences_set, but it has not been added to block_table." + continue + assert ( + seq.request_id not in block_table + ), "The sequence has not been added to sequences_set, but it is already in block_table." + + sequences_set.add(seq) + block_table[seq.request_id] = seq.block_table_index + + return cls(sequences_set=sequences_set, block_table=block_table) + + def clear_batch(self) -> None: + """ + Clear sequence set and block table. + """ + for seq in self.sequences_set: + if not seq.check_finish(): + seq.status = RequsetStatus.ABORTED + self.sequences_set.clear() + self.block_table.clear() + + def fliter_batch(self) -> None: + """ + Remove completed sentences from a batch. + """ + for seq in self.sequences_set: + if seq.check_finish(): + self.sequences_set.reomve(seq) + del self.block_table[seq.request_id] + + def add_seqs(self, seqs: List[Sequence]) -> None: + """ + Add new sequence to batch + + Args: + seqs (List[Sequence]): The list of new sequences. + """ + for seq in seqs: + if seq in self.sequences_set: + print("The sequence is already in sequences_set.") + assert ( + seq.request_id in self.block_table + ), "The sequence has been added to sequences_set, but it has not been added to block_table." + continue + assert ( + seq.request_id not in self.block_table + ), "The sequence has not been added to sequences_set, but it is already in block_table." + self.sequences_set.add(seq) + self.block_table[seq.request_id] = seq.block_table_index diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py new file mode 100644 index 000000000..580396e51 --- /dev/null +++ b/tests/test_infer/test_config_and_struct.py @@ -0,0 +1,37 @@ +from colossalai.inference.core.config import InferenceConfig +from colossalai.inference.core.inference_struct import BatchHandler, Sequence + + +def test_config_and_struct(): + InferenceConfig("/llama") + sequence = Sequence( + request_id=1, + prompt="abc", + token_id=[1, 2, 3], + block_size=16, + sample_params=None, + block_table_index=1, + ) + + sequence2 = Sequence( + request_id=2, + prompt="bcd", + token_id=[4, 5, 6], + block_size=16, + sample_params=None, + block_table_index=2, + ) + + assert sequence.get_sentence_len() == 3 + assert sequence.get_input_len() == 3 + assert sequence.get_output_len() == 0 + assert sequence.check_finish() == False + + batch = BatchHandler.init_batch([sequence]) + batch.fliter_batch() + batch.add_seqs([sequence2]) + batch.clear_batch() + + +if __name__ == "__main__": + test_config_and_struct()