mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Update inference config and fix test (#5178)
* unify the config setting
* fix test
* fix import
* fix test
* fix
* fix
* add logger
* revise log info

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
parent 3de2e62299
commit 93aeacca34
@@ -1,9 +1,14 @@
+import logging
 from dataclasses import dataclass
 from typing import Optional, Union

 import torch
 import torch.nn as nn

+GibiByte = 1024**3
+
+logger = logging.Logger(__name__)
+

 @dataclass
 class InferenceConfig:
@@ -18,7 +23,6 @@ class InferenceConfig:
         max_output_len: Maximum output length.
         max_input_len: Maximum input length.
         block_size: The number of blocks in a logical block.
-        gpu_utilization_rate: Maximum GPU memory usage ratio.
         dtype: The data type for weights and activations.
         tp_size: Tensor parallel size.
         pp_size: Pipeline parallel size.
@@ -27,13 +31,15 @@
         revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
         beam_width: The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
+        prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
+            when the actual value exceeds this ratio.
     """

     model: Union[str, nn.Module]
     tokenizer: str = None
     tokenizer_mode: str = "auto"
     trust_remote_code: bool = False
-    max_batch_size: int = 8
+    max_batch_size: int = None
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
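Note: this hunk only documents `prefill_ratio`; the scheduler that consumes it is not part of this diff. Below is a purely illustrative sketch of how such a ratio could gate a prefill step. The function name and arguments are assumptions, not ColossalAI API.

```python
# Illustrative only: a hypothetical check using prefill_ratio (not code from this commit).
def should_prefill(num_waiting_prefill: int, num_decoding: int, prefill_ratio: float = 1.2) -> bool:
    if num_decoding == 0:
        # nothing is decoding yet, so any waiting sequence triggers a prefill step
        return num_waiting_prefill > 0
    # do a prefill step once the prefill/decoding ratio exceeds the configured threshold
    return (num_waiting_prefill / num_decoding) > prefill_ratio
```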
@@ -43,10 +49,34 @@ class InferenceConfig:
     max_seq_len: Optional[int] = None
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
-    # TODO: beam search is not support for now
     beam_width: int = 1
+    # TODO: beam search is not support for now
+    prefill_ratio: Optional[float] = 1.2
+    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+
+    def _init_batch_size(self):
+        """
+        MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
+        We take a simple method to determine it by GPU memory size, user can still set it manually.
+        """
+        if self.max_batch_size is not None:
+            # already set by user
+            return
+
+        device = torch.device("cuda")
+        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
+        self.max_batch_size = 8
+
+        if 40 < total_mem <= 60:
+            self.max_batch_size = 16
+        elif 60 < total_mem <= 80:
+            self.max_batch_size = 32
+        logger.info(
+            f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
+        )

     def __post_init__(self):
+        self._init_batch_size()
         self._verify_args()

     def _verify_args(self):
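In short: when `max_batch_size` is left as `None`, `__post_init__` now calls `_init_batch_size`, which derives a value from total GPU memory (8 by default, 16 for 40-60 GiB, 32 for 60-80 GiB); a user-supplied value is kept as-is. A minimal usage sketch, assuming a CUDA device is available and using the placeholder model path `"/llama"` from the updated test:

```python
import torch

from colossalai.inference.config import InferenceConfig

assert torch.cuda.is_available()  # _init_batch_size queries torch.cuda device properties

auto_cfg = InferenceConfig("/llama")              # max_batch_size defaults to None ...
print(auto_cfg.max_batch_size)                    # ... so it is auto-set to 8 / 16 / 32 by GPU memory

manual_cfg = InferenceConfig("/llama", max_batch_size=4)
print(manual_cfg.max_batch_size)                  # an explicit user value (4) is kept untouched
```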
@@ -3,7 +3,7 @@ from typing import Optional

 from transformers import AutoConfig

-from .config import InferenceConfig
+from colossalai.inference.config import InferenceConfig


 class InferenceEngine:
@@ -3,7 +3,7 @@ from typing import List, Tuple
 import torch
 from transformers.configuration_utils import PretrainedConfig

-from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device

@@ -4,8 +4,7 @@ Colossal-Infer is a library for inference of LLMs and MLMs. It is built on top o

 ## Structures
 ### Overview
-https://n4fyd3ptax.feishu.cn/docx/MhlmdHsGkoeoslx9fqucPO17n9b?openbrd=1&doc_app_id=501&blockId=WCGBdWI9hobOEsxkW5uc8HM6n3b&blockType=whiteboard&blockToken=Cca3wKWk7hPnJxbkCX6cMxPQnqd#WCGBdWI9hobOEsxkW5uc8HM6n3b
+The main design will be released later on.

 ## Roadmap
 - [] design of structures
 - [] Core components
@@ -1,3 +0,0 @@
-"""
-The abstraction of request and sequence are defined here.
-"""
@@ -2,6 +2,10 @@ import enum
 from dataclasses import dataclass
 from typing import Dict, List, Set

+"""
+The abstraction of request and sequence are defined here.
+"""
+

 class RequsetStatus(enum.Enum):
     """The status of Sentences"""
@@ -95,16 +99,16 @@ class Sequence:


 @dataclass
-class BatchHandler:
+class BatchInfo:
     """
     Information to be passed and used for a batch of sequences.
     """

     sequences_set: Set[Sequence]
-    block_table: Dict[int, int]
+    block_table: Dict[int, int] = None

     @classmethod
-    def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
+    def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo":
         """
         Initializes inference batches by input sentence list.

@@ -115,13 +119,13 @@ class BatchHandler:
         block_table = {}
         for seq in seqs:
             if seq in sequences_set:
-                print("The sequence is already in sequences_set.")
                 assert (
-                    seq.request_id in block_table
+                    seq.request_id in block_table.keys()
                 ), "The sequence has been added to sequences_set, but it has not been added to block_table."
                 continue

             assert (
-                seq.request_id not in block_table
+                seq.request_id not in block_table.keys()
             ), "The sequence has not been added to sequences_set, but it is already in block_table."

             sequences_set.add(seq)
|
@ -143,9 +147,9 @@ class BatchHandler:
|
||||||
"""
|
"""
|
||||||
Remove completed sentences from a batch.
|
Remove completed sentences from a batch.
|
||||||
"""
|
"""
|
||||||
for seq in self.sequences_set:
|
for seq in self.sequences_set.copy():
|
||||||
if seq.check_finish():
|
if seq.check_finish():
|
||||||
self.sequences_set.reomve(seq)
|
self.sequences_set.remove(seq)
|
||||||
del self.block_table[seq.request_id]
|
del self.block_table[seq.request_id]
|
||||||
|
|
||||||
def add_seqs(self, seqs: List[Sequence]) -> None:
|
def add_seqs(self, seqs: List[Sequence]) -> None:
|
|
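For reference, a minimal sketch of how the renamed `BatchInfo` is driven; it mirrors the updated test in the next hunk. The `Sequence` constructor arguments beyond those shown in this diff are omitted, so `seq` is assumed to be an already-built `Sequence` (class name `RequsetStatus` and method name `fliter_batch` are spelled as in the source).

```python
from colossalai.inference.struct import BatchInfo, RequsetStatus


def run_batch_lifecycle(seq) -> None:
    # seq: an already-constructed Sequence (its full constructor is not shown in this diff)
    batch = BatchInfo.init_batch([seq])                 # fills sequences_set and block_table
    assert batch.block_table[seq.request_id] == seq.block_table_index

    seq.status = RequsetStatus.COMPLETED                # mark the sequence as finished
    batch.fliter_batch()                                # drops finished sequences and their block_table entries
    assert batch.block_table == {}
```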
@@ -1,9 +1,10 @@
-from colossalai.inference.core.config import InferenceConfig
-from colossalai.inference.core.inference_struct import BatchHandler, Sequence
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence


-def test_config_and_struct():
-    InferenceConfig("/llama")
+def test_config_and_inferenceData():
+    config = InferenceConfig("/llama")
+    assert config.max_batch_size
     sequence = Sequence(
         request_id=1,
         prompt="abc",
@@ -27,11 +28,16 @@ def test_config_and_struct():
     assert sequence.get_output_len() == 0
     assert sequence.check_finish() == False

-    batch = BatchHandler.init_batch([sequence])
+    batch = BatchInfo.init_batch([sequence])
+    assert batch.block_table[sequence.request_id] == sequence.block_table_index
+    sequence.status = RequsetStatus.COMPLETED
     batch.fliter_batch()
+    assert batch.block_table == {}
     batch.add_seqs([sequence2])
+    assert batch.block_table[sequence2.request_id] == sequence2.block_table_index
     batch.clear_batch()
+    assert batch.block_table == {}


 if __name__ == "__main__":
-    test_config_and_struct()
+    test_config_and_inferenceData()
@@ -3,7 +3,7 @@ import random
 import torch
 from transformers.models.llama import LlamaConfig

-from colossalai.inference.core.config import InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import parameterize
