[Inference] Update inference config and fix test (#5178)

* unify the config setting

* fix test

* fix import

* fix test

* fix

* fix

* add logger

* revise log info

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Jianghai 2023-12-12 17:22:41 +08:00 committed by FrankLeeeee
parent 3de2e62299
commit 93aeacca34
9 changed files with 61 additions and 25 deletions


@@ -1,9 +1,14 @@
+import logging
 from dataclasses import dataclass
 from typing import Optional, Union
 import torch
 import torch.nn as nn
+GibiByte = 1024**3
+logger = logging.Logger(__name__)
 @dataclass
 class InferenceConfig:
@@ -18,7 +23,6 @@ class InferenceConfig:
         max_output_len: Maximum output length.
         max_input_len: Maximum input length.
         block_size: The number of blocks in a logical block.
-        gpu_utilization_rate: Maximum GPU memory usage ratio.
         dtype: The data type for weights and activations.
         tp_size: Tensor parallel size.
         pp_size: Pipeline parallel size.
@@ -27,13 +31,15 @@ class InferenceConfig:
         revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
         beam_width: The maximum beam width used to initialize KV Cache.
            During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
+        prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
+           when the actual value exceeds this ratio.
     """
     model: Union[str, nn.Module]
     tokenizer: str = None
     tokenizer_mode: str = "auto"
     trust_remote_code: bool = False
-    max_batch_size: int = 8
+    max_batch_size: int = None
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
@@ -43,10 +49,34 @@ class InferenceConfig:
     max_seq_len: Optional[int] = None
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
-    # TODO: beam search is not support for now
     beam_width: int = 1
+    # TODO: beam search is not support for now
+    prefill_ratio: Optional[float] = 1.2
+    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    def _init_batch_size(self):
+        """
+        MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
+        We take a simple method to determine it by GPU memory size, user can still set it manually.
+        """
+        if self.max_batch_size is not None:
+            # already set by user
+            return
+        device = torch.device("cuda")
+        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
+        self.max_batch_size = 8
+        if 40 < total_mem <= 60:
+            self.max_batch_size = 16
+        elif 60 < total_mem <= 80:
+            self.max_batch_size = 32
+        logger.info(
+            f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
+        )
     def __post_init__(self):
+        self._init_batch_size()
         self._verify_args()
     def _verify_args(self):
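
With `max_batch_size` now defaulting to `None`, the new `_init_batch_size` hook derives a value from total GPU memory during `__post_init__`. A minimal usage sketch of that behaviour, assuming a CUDA device is available and the `colossalai.inference.config` module path introduced in this diff:

```python
from colossalai.inference.config import InferenceConfig

# "/llama" is the same placeholder model path the updated test uses below.
config = InferenceConfig("/llama")

# _init_batch_size() runs inside __post_init__: with max_batch_size left as None,
# it starts from 8, bumps to 16 when total GPU memory is in (40, 60] GiB and to 32
# in (60, 80] GiB, then logs the auto-selected value.
print(config.max_batch_size)

# An explicit value short-circuits the auto-detection.
assert InferenceConfig("/llama", max_batch_size=4).max_batch_size == 4
```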


@@ -3,7 +3,7 @@ from typing import Optional
 from transformers import AutoConfig
-from .config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 class InferenceEngine:


@@ -3,7 +3,7 @@ from typing import List, Tuple
 import torch
 from transformers.configuration_utils import PretrainedConfig
-from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device


@@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o
 ## Structures
 ### Overview
-https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b
+The main design will be released later on.
 ## Roadmap
 - [] design of structures
 - [] Core components


@@ -1,3 +0,0 @@
-"""
-The abstraction of request and sequence are defined here.
-"""


@@ -2,6 +2,10 @@ import enum
 from dataclasses import dataclass
 from typing import Dict, List, Set
+"""
+The abstraction of request and sequence are defined here.
+"""
 class RequsetStatus(enum.Enum):
     """The status of Sentences"""
@@ -95,16 +99,16 @@ class Sequence:
 @dataclass
-class BatchHandler:
+class BatchInfo:
     """
     Information to be passed and used for a batch of sequences.
     """
     sequences_set: Set[Sequence]
-    block_table: Dict[int, int]
+    block_table: Dict[int, int] = None
     @classmethod
-    def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
+    def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
         """
         Initializes inference batches by input sentence list.
@@ -115,13 +119,13 @@ class BatchHandler:
         block_table = {}
         for seq in seqs:
             if seq in sequences_set:
+                print("The sequence is already in sequences_set.")
                 assert (
-                    seq.request_id in block_table
+                    seq.request_id in block_table.keys()
                 ), "The sequence has been added to sequences_set, but it has not been added to block_table."
                 continue
             assert (
-                seq.request_id not in block_table
+                seq.request_id not in block_table.keys()
             ), "The sequence has not been added to sequences_set, but it is already in block_table."
             sequences_set.add(seq)
@@ -143,9 +147,9 @@ class BatchHandler:
         """
         Remove completed sentences from a batch.
         """
-        for seq in self.sequences_set:
+        for seq in self.sequences_set.copy():
             if seq.check_finish():
-                self.sequences_set.reomve(seq)
+                self.sequences_set.remove(seq)
                 del self.block_table[seq.request_id]
     def add_seqs(self, seqs: List[Sequence]) -> None:
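
The `fliter_batch` fix above also switches the loop to `self.sequences_set.copy()`: removing elements from a Python set while iterating over it raises `RuntimeError: Set changed size during iteration`, so finished sequences have to be dropped while scanning a snapshot. A standalone sketch of that pattern (plain Python, not ColossalAI code):

```python
# Mirror of the fliter_batch() pattern: iterate over a snapshot of the set and
# mutate only the original, so finished entries can be removed mid-scan.
pending = {"req-1", "req-2", "req-3"}
finished = {"req-1", "req-3"}

for req in pending.copy():
    if req in finished:
        pending.remove(req)  # safe: the loop runs over the copy

print(pending)  # {'req-2'}

# Looping over the live set instead fails part-way through:
#   for req in pending:
#       pending.remove(req)  # RuntimeError: Set changed size during iteration
```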


@@ -1,9 +1,10 @@
-from colossalai.inference.core.config import InferenceConfig
-from colossalai.inference.core.inference_struct import BatchHandler, Sequence
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
-def test_config_and_struct():
-    InferenceConfig("/llama")
+def test_config_and_inferenceData():
+    config = InferenceConfig("/llama")
+    assert config.max_batch_size
     sequence = Sequence(
         request_id=1,
         prompt="abc",
@@ -27,11 +28,16 @@ def test_config_and_struct():
     assert sequence.get_output_len() == 0
     assert sequence.check_finish() == False
-    batch = BatchHandler.init_batch([sequence])
+    batch = BatchInfo.init_batch([sequence])
+    assert batch.block_table[sequence.request_id] == sequence.block_table_index
+    sequence.status = RequsetStatus.COMPLETED
     batch.fliter_batch()
+    assert batch.block_table == {}
     batch.add_seqs([sequence2])
+    assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
     batch.clear_batch()
+    assert batch.block_table == {}
 if __name__ == "__main__":
-    test_config_and_struct()
+    test_config_and_inferenceData()


@@ -3,7 +3,7 @@ import random
 import torch
 from transformers.models.llama import LlamaConfig
-from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import parameterize