mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Update inference config and fix test (#5178)
* unify the config setting * fix test * fix import * fix test * fix * fix * add logger * revise log info --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com>pull/5258/head
parent
3de2e62299
commit
93aeacca34
|
@ -1,9 +1,14 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
GibiByte = 1024**3
|
||||
|
||||
logger = logging.Logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceConfig:
|
||||
|
@ -18,7 +23,6 @@ class InferenceConfig:
|
|||
max_output_len: Maximum output length.
|
||||
max_input_len: Maximum input length.
|
||||
block_size: The number of blocks in a logical block.
|
||||
gpu_utilization_rate: Maximum GPU memory usage ratio.
|
||||
dtype: The data type for weights and activations.
|
||||
tp_size: Tensor parallel size.
|
||||
pp_size: Pipeline parallel size.
|
||||
|
@ -27,13 +31,15 @@ class InferenceConfig:
|
|||
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
|
||||
beam_width: The maximum beam width used to initialize KV Cache.
|
||||
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
|
||||
prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
|
||||
when the actual value exceeds this ratio.
|
||||
"""
|
||||
|
||||
model: Union[str, nn.Module]
|
||||
tokenizer: str = None
|
||||
tokenizer_mode: str = "auto"
|
||||
trust_remote_code: bool = False
|
||||
max_batch_size: int = 8
|
||||
max_batch_size: int = None
|
||||
max_output_len: int = 256
|
||||
max_input_len: int = 256
|
||||
block_size: int = 16
|
||||
|
@ -43,10 +49,34 @@ class InferenceConfig:
|
|||
max_seq_len: Optional[int] = None
|
||||
quant_mode: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
# TODO: beam search is not support for now
|
||||
beam_width: int = 1
|
||||
# TODO: beam search is not support for now
|
||||
prefill_ratio: Optional[float] = 1.2
|
||||
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
|
||||
def _init_batch_size(self):
|
||||
"""
|
||||
MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
|
||||
We take a simple method to determine it by GPU memory size, user can still set it manually.
|
||||
"""
|
||||
if self.max_batch_size is not None:
|
||||
# already set by user
|
||||
return
|
||||
|
||||
device = torch.device("cuda")
|
||||
total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
|
||||
self.max_batch_size = 8
|
||||
|
||||
if 40 < total_mem <= 60:
|
||||
self.max_batch_size = 16
|
||||
elif 60 < total_mem <= 80:
|
||||
self.max_batch_size = 32
|
||||
logger.info(
|
||||
f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self._init_batch_size()
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self):
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
|
||||
from transformers import AutoConfig
|
||||
|
||||
from .config import InferenceConfig
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Tuple
|
|||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from colossalai.inference.core.config import InferenceConfig
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
|
|
@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o
|
|||
|
||||
## Structures
|
||||
### Overview
|
||||
https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b
|
||||
|
||||
The main design will be released later on.
|
||||
## Roadmap
|
||||
- [] design of structures
|
||||
- [] Core components
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
"""
|
||||
The abstraction of request and sequence are defined here.
|
||||
"""
|
|
@ -2,6 +2,10 @@ import enum
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Set
|
||||
|
||||
"""
|
||||
The abstraction of request and sequence are defined here.
|
||||
"""
|
||||
|
||||
|
||||
class RequsetStatus(enum.Enum):
|
||||
"""The status of Sentences"""
|
||||
|
@ -95,16 +99,16 @@ class Sequence:
|
|||
|
||||
|
||||
@dataclass
|
||||
class BatchHandler:
|
||||
class BatchInfo:
|
||||
"""
|
||||
Information to be passed and used for a batch of sequences.
|
||||
"""
|
||||
|
||||
sequences_set: Set[Sequence]
|
||||
block_table: Dict[int, int]
|
||||
block_table: Dict[int, int] = None
|
||||
|
||||
@classmethod
|
||||
def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
|
||||
def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
|
||||
"""
|
||||
Initializes inference batches by input sentence list.
|
||||
|
||||
|
@ -115,13 +119,13 @@ class BatchHandler:
|
|||
block_table = {}
|
||||
for seq in seqs:
|
||||
if seq in sequences_set:
|
||||
print("The sequence is already in sequences_set.")
|
||||
assert (
|
||||
seq.request_id in block_table
|
||||
seq.request_id in block_table.keys()
|
||||
), "The sequence has been added to sequences_set, but it has not been added to block_table."
|
||||
continue
|
||||
|
||||
assert (
|
||||
seq.request_id not in block_table
|
||||
seq.request_id not in block_table.keys()
|
||||
), "The sequence has not been added to sequences_set, but it is already in block_table."
|
||||
|
||||
sequences_set.add(seq)
|
||||
|
@ -143,9 +147,9 @@ class BatchHandler:
|
|||
"""
|
||||
Remove completed sentences from a batch.
|
||||
"""
|
||||
for seq in self.sequences_set:
|
||||
for seq in self.sequences_set.copy():
|
||||
if seq.check_finish():
|
||||
self.sequences_set.reomve(seq)
|
||||
self.sequences_set.remove(seq)
|
||||
del self.block_table[seq.request_id]
|
||||
|
||||
def add_seqs(self, seqs: List[Sequence]) -> None:
|
|
@ -1,9 +1,10 @@
|
|||
from colossalai.inference.core.config import InferenceConfig
|
||||
from colossalai.inference.core.inference_struct import BatchHandler, Sequence
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence
|
||||
|
||||
|
||||
def test_config_and_struct():
|
||||
InferenceConfig("/llama")
|
||||
def test_config_and_inferenceData():
|
||||
config = InferenceConfig("/llama")
|
||||
assert config.max_batch_size
|
||||
sequence = Sequence(
|
||||
request_id=1,
|
||||
prompt="abc",
|
||||
|
@ -27,11 +28,16 @@ def test_config_and_struct():
|
|||
assert sequence.get_output_len() == 0
|
||||
assert sequence.check_finish() == False
|
||||
|
||||
batch = BatchHandler.init_batch([sequence])
|
||||
batch = BatchInfo.init_batch([sequence])
|
||||
assert batch.block_table[sequence.request_id] == sequence.block_table_index
|
||||
sequence.status = RequsetStatus.COMPLETED
|
||||
batch.fliter_batch()
|
||||
assert batch.block_table == {}
|
||||
batch.add_seqs([sequence2])
|
||||
assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
|
||||
batch.clear_batch()
|
||||
assert batch.block_table == {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_config_and_struct()
|
||||
test_config_and_inferenceData()
|
||||
|
|
|
@ -3,7 +3,7 @@ import random
|
|||
import torch
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
from colossalai.inference.core.config import InferenceConfig
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize
|
||||
|
|
Loading…
Reference in New Issue