From 93aeacca342ab03732362dbb9096ab1265f4a8b3 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Tue, 12 Dec 2023 17:22:41 +0800
Subject: [PATCH] [Inference]Update inference config and fix test (#5178)

* unify the config setting

* fix test

* fix import

* fix test

* fix

* fix

* add logger

* revise log info

---------

Co-authored-by: CjhHa1
---
 colossalai/inference/{core => }/config.py   | 36 +++++++++++++++++--
 colossalai/inference/core/cache_manager.py  |  0
 colossalai/inference/core/engine.py         |  2 +-
 .../inference/kv_cache/kvcache_manager.py   |  2 +-
 colossalai/inference/readme.md              |  3 +-
 colossalai/inference/sequence.py            |  3 --
 .../{core/inference_struct.py => struct.py} | 20 ++++++-----
 tests/test_infer/test_config_and_struct.py  | 18 ++++++----
 tests/test_infer/test_kvcache_manager.py    |  2 +-
 9 files changed, 61 insertions(+), 25 deletions(-)
 rename colossalai/inference/{core => }/config.py (61%)
 delete mode 100644 colossalai/inference/core/cache_manager.py
 delete mode 100644 colossalai/inference/sequence.py
 rename colossalai/inference/{core/inference_struct.py => struct.py} (92%)

diff --git a/colossalai/inference/core/config.py b/colossalai/inference/config.py
similarity index 61%
rename from colossalai/inference/core/config.py
rename to colossalai/inference/config.py
index 43d0b2bb2..ea06335b7 100644
--- a/colossalai/inference/core/config.py
+++ b/colossalai/inference/config.py
@@ -1,9 +1,14 @@
+import logging
 from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
 import torch.nn as nn
 
+GibiByte = 1024**3
+
+logger = logging.Logger(__name__)
+
 
 @dataclass
 class InferenceConfig:
@@ -18,7 +23,6 @@ class InferenceConfig:
         max_output_len: Maximum output length.
         max_input_len: Maximum input length.
         block_size: The number of blocks in a logical block.
-        gpu_utilization_rate: Maximum GPU memory usage ratio.
         dtype: The data type for weights and activations.
         tp_size: Tensor parallel size.
         pp_size: Pipeline parallel size.
@@ -27,13 +31,15 @@ class InferenceConfig:
         revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
         beam_width: The maximum beam width used to initialize KV Cache.
            During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
+        prefill_ratio: A controlling ratio for prefill and decoding in the running list; we will do a step of prefill
+            when the actual value exceeds this ratio.
     """
 
     model: Union[str, nn.Module]
     tokenizer: str = None
     tokenizer_mode: str = "auto"
     trust_remote_code: bool = False
-    max_batch_size: int = 8
+    max_batch_size: int = None
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
@@ -43,10 +49,34 @@ class InferenceConfig:
     max_seq_len: Optional[int] = None
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
-    # TODO: beam search is not support for now
     beam_width: int = 1
+    # TODO: beam search is not supported for now
+    prefill_ratio: Optional[float] = 1.2
+    # the ratio of prefill sequences to decoding sequences; we do a prefill step once the actual value exceeds this ratio
+
+    def _init_batch_size(self):
+        """
+        max_batch_size is set to accurately utilize the memory of the GPU.
+        We use a simple method to determine it from the total GPU memory; the user can still set it manually.
+        """
+        if self.max_batch_size is not None:
+            # already set by user
+            return
+
+        device = torch.device("cuda")
+        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
+        self.max_batch_size = 8
+
+        if 40 < total_mem <= 60:
+            self.max_batch_size = 16
+        elif 60 < total_mem <= 80:
+            self.max_batch_size = 32
+        logger.info(
+            f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
+        )
 
     def __post_init__(self):
+        self._init_batch_size()
         self._verify_args()
 
     def _verify_args(self):
diff --git a/colossalai/inference/core/cache_manager.py b/colossalai/inference/core/cache_manager.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 7f78e9761..232bfb188 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -3,7 +3,7 @@ from typing import Optional
 
 from transformers import AutoConfig
 
-from .config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 
 
 class InferenceEngine:
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 8bf7af61c..493613d68 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -3,7 +3,7 @@ from typing import List, Tuple
 import torch
 from transformers.configuration_utils import PretrainedConfig
 
-from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
diff --git a/colossalai/inference/readme.md b/colossalai/inference/readme.md
index 301b546ff..e87e46f05 100644
--- a/colossalai/inference/readme.md
+++ b/colossalai/inference/readme.md
@@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o
 ## Structures
 ### Overview
-https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b
-
+The main design will be released later on.
 ## Roadmap
 - [] design of structures
 - [] Core components
diff --git a/colossalai/inference/sequence.py b/colossalai/inference/sequence.py
deleted file mode 100644
index 74ec631f4..000000000
--- a/colossalai/inference/sequence.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""
-The abstraction of request and sequence are defined here.
-"""
diff --git a/colossalai/inference/core/inference_struct.py b/colossalai/inference/struct.py
similarity index 92%
rename from colossalai/inference/core/inference_struct.py
rename to colossalai/inference/struct.py
index 331f0308a..a5201d787 100644
--- a/colossalai/inference/core/inference_struct.py
+++ b/colossalai/inference/struct.py
@@ -2,6 +2,10 @@ import enum
 from dataclasses import dataclass
 from typing import Dict, List, Set
 
+"""
+The abstractions of request and sequence are defined here.
+"""
+
 
 class RequsetStatus(enum.Enum):
     """The status of Sentences"""
@@ -95,16 +99,16 @@ class Sequence:
 
 
 @dataclass
-class BatchHandler:
+class BatchInfo:
     """
     Information to be passed and used for a batch of sequences.
""" sequences_set: Set[Sequence] - block_table: Dict[int, int] + block_table: Dict[int, int] = None @classmethod - def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler": + def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo": """ Initializes inference batches by input sentence list. @@ -115,13 +119,13 @@ class BatchHandler: block_table = {} for seq in seqs: if seq in sequences_set: - print("The sequence is already in sequences_set.") assert ( - seq.request_id in block_table + seq.request_id in block_table.keys() ), "The sequence has been added to sequences_set, but it has not been added to block_table." continue + assert ( - seq.request_id not in block_table + seq.request_id not in block_table.keys() ), "The sequence has not been added to sequences_set, but it is already in block_table." sequences_set.add(seq) @@ -143,9 +147,9 @@ class BatchHandler: """ Remove completed sentences from a batch. """ - for seq in self.sequences_set: + for seq in self.sequences_set.copy(): if seq.check_finish(): - self.sequences_set.reomve(seq) + self.sequences_set.remove(seq) del self.block_table[seq.request_id] def add_seqs(self, seqs: List[Sequence]) -> None: diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py index 580396e51..329165025 100644 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -1,9 +1,10 @@ -from colossalai.inference.core.config import InferenceConfig -from colossalai.inference.core.inference_struct import BatchHandler, Sequence +from colossalai.inference.config import InferenceConfig +from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence -def test_config_and_struct(): - InferenceConfig("/llama") +def test_config_and_inferenceData(): + config = InferenceConfig("/llama") + assert config.max_batch_size sequence = Sequence( request_id=1, prompt="abc", @@ -27,11 +28,16 @@ def test_config_and_struct(): assert sequence.get_output_len() == 0 assert sequence.check_finish() == False - batch = BatchHandler.init_batch([sequence]) + batch = BatchInfo.init_batch([sequence]) + assert batch.block_table[sequence.request_id] == sequence.block_table_index + sequence.status = RequsetStatus.COMPLETED batch.fliter_batch() + assert batch.block_table == {} batch.add_seqs([sequence2]) + assert batch.block_table[sequence2.request_id] == sequence2.block_table_index batch.clear_batch() + assert batch.block_table == {} if __name__ == "__main__": - test_config_and_struct() + test_config_and_inferenceData() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index ee37f3ce1..5187727f1 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -3,7 +3,7 @@ import random import torch from transformers.models.llama import LlamaConfig -from colossalai.inference.core.config import InferenceConfig +from colossalai.inference.config import InferenceConfig from colossalai.inference.kv_cache import CacheBlock, KVCacheManager from colossalai.logging import disable_existing_loggers from colossalai.testing import parameterize