[Inference]Update inference config and fix test (#5178)

* unify the config setting

* fix test

* fix import

* fix test

* fix

* fix

* add logger

* revise log info

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
pull/5258/head
Jianghai 2023-12-12 17:22:41 +08:00 committed by FrankLeeeee
parent 3de2e62299
commit 93aeacca34
9 changed files with 61 additions and 25 deletions

View File

@ -1,9 +1,14 @@
import logging
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
GibiByte = 1024**3
logger = logging.Logger(__name__)
@dataclass
class InferenceConfig:
@ -18,7 +23,6 @@ class InferenceConfig:
max_output_len: Maximum output length.
max_input_len: Maximum input length.
block_size: The number of blocks in a logical block.
gpu_utilization_rate: Maximum GPU memory usage ratio.
dtype: The data type for weights and activations.
tp_size: Tensor parallel size.
pp_size: Pipeline parallel size.
@ -27,13 +31,15 @@ class InferenceConfig:
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
beam_width: The maximum beam width used to initialize KV Cache.
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
when the actual value exceeds this ratio.
"""
model: Union[str, nn.Module]
tokenizer: str = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
max_batch_size: int = 8
max_batch_size: int = None
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
@ -43,10 +49,34 @@ class InferenceConfig:
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
# TODO: beam search is not support for now
beam_width: int = 1
# TODO: beam search is not support for now
prefill_ratio: Optional[float] = 1.2
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
def _init_batch_size(self):
"""
MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
We take a simple method to determine it by GPU memory size, user can still set it manually.
"""
if self.max_batch_size is not None:
# already set by user
return
device = torch.device("cuda")
total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
self.max_batch_size = 8
if 40 < total_mem <= 60:
self.max_batch_size = 16
elif 60 < total_mem <= 80:
self.max_batch_size = 32
logger.info(
f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
)
def __post_init__(self):
self._init_batch_size()
self._verify_args()
def _verify_args(self):

View File

@ -3,7 +3,7 @@ from typing import Optional
from transformers import AutoConfig
from .config import InferenceConfig
from colossalai.inference.config import InferenceConfig
class InferenceEngine:

View File

@ -3,7 +3,7 @@ from typing import List, Tuple
import torch
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

View File

@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o
## Structures
### Overview
https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b
The main design will be released later on.
## Roadmap
- [] design of structures
- [] Core components

View File

@ -1,3 +0,0 @@
"""
The abstraction of request and sequence are defined here.
"""

View File

@ -2,6 +2,10 @@ import enum
from dataclasses import dataclass
from typing import Dict, List, Set
"""
The abstraction of request and sequence are defined here.
"""
class RequsetStatus(enum.Enum):
"""The status of Sentences"""
@ -95,16 +99,16 @@ class Sequence:
@dataclass
class BatchHandler:
class BatchInfo:
"""
Information to be passed and used for a batch of sequences.
"""
sequences_set: Set[Sequence]
block_table: Dict[int, int]
block_table: Dict[int, int] = None
@classmethod
def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
"""
Initializes inference batches by input sentence list.
@ -115,13 +119,13 @@ class BatchHandler:
block_table = {}
for seq in seqs:
if seq in sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in block_table
seq.request_id in block_table.keys()
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue
assert (
seq.request_id not in block_table
seq.request_id not in block_table.keys()
), "The sequence has not been added to sequences_set, but it is already in block_table."
sequences_set.add(seq)
@ -143,9 +147,9 @@ class BatchHandler:
"""
Remove completed sentences from a batch.
"""
for seq in self.sequences_set:
for seq in self.sequences_set.copy():
if seq.check_finish():
self.sequences_set.reomve(seq)
self.sequences_set.remove(seq)
del self.block_table[seq.request_id]
def add_seqs(self, seqs: List[Sequence]) -> None:

View File

@ -1,9 +1,10 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
def test_config_and_struct():
InferenceConfig("/llama")
def test_config_and_inferenceData():
config = InferenceConfig("/llama")
assert config.max_batch_size
sequence = Sequence(
request_id=1,
prompt="abc",
@ -27,11 +28,16 @@ def test_config_and_struct():
assert sequence.get_output_len() == 0
assert sequence.check_finish() == False
batch = BatchHandler.init_batch([sequence])
batch = BatchInfo.init_batch([sequence])
assert batch.block_table[sequence.request_id] == sequence.block_table_index
sequence.status = RequsetStatus.COMPLETED
batch.fliter_batch()
assert batch.block_table == {}
batch.add_seqs([sequence2])
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
batch.clear_batch()
assert batch.block_table == {}
if __name__ == "__main__":
test_config_and_struct()
test_config_and_inferenceData()

View File

@ -3,7 +3,7 @@ import random
import torch
from transformers.models.llama import LlamaConfig
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize