Browse Source

[Inference]Add BatchInferState, Sequence and InferConfig (#5149)

* add infer_struct and infer_config

* update codes

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update codes

* made adjustments according to the feedback from the reviewer.

* update codes

* add ci test for config and struct
pull/5258/head
yuehuayingxueluo 12 months ago committed by FrankLeeeee
parent
commit
fab9b931d9
  1. 7
      colossalai/inference/config.py
  2. 54
      colossalai/inference/core/config.py
  3. 46
      colossalai/inference/core/engine.py
  4. 169
      colossalai/inference/core/inference_struct.py
  5. 37
      tests/test_infer/test_config_and_struct.py

7
colossalai/inference/config.py

@ -1,7 +0,0 @@
"""
Our config consists of three parts:
1. model_config: The configuration for the model, including `model name`, 'model path' and self-defined layer.
2. parallel_config: The configuration for parallelize model, including `tp_size`,'pp_size', `world size`, `local rank`, `master port`, `master ip`.
3. cache_config: Configuration for initialize and manage kv cache, including `block size`, `block num`
For the convenience of users, we provide a unified config api for that wrapped all the configs. One can easily construct a colossal_config by setting the needed configs.
"""

54
colossalai/inference/core/config.py

@ -0,0 +1,54 @@
from typing import Optional, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
@dataclass
class InferenceConfig:
"""The inference configuration.
Args:
model: Path or nn.Module of this model.
tokenizer: Path of the tokenizer to use.
tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Whether to trust remote code from huggingface.
max_batch_size: Maximum batch size.
max_output_len: Maximum output length.
max_input_len: Maximum input length.
block_size: The number of blocks in a logical block.
gpu_utilization_rate: Maximum GPU memory usage ratio.
dtype: The data type for weights and activations.
tp_size: Tensor parallel size.
pp_size: Pipeline parallel size.
max_seq_len: Maximum length of input sentence.
quant_mode: Quantization mode.
revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
"""
model: Union[str, nn.Module]
tokenizer: str = None
tokenizer_mode: str = "auto"
trust_remote_code: bool = False
max_batch_size: int = 8
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16
gpu_utilization_rate: float = 0.7
dtype: Union[str, torch.dtype] = torch.float32
tp_size: int = 1
pp_size: int = 1
max_seq_len: Optional[int] = None
quant_mode: Optional[str] = None
revision: Optional[str] = None
def __post_init__(self):
self._verify_args()
def _verify_args(self):
if self.gpu_utilization_rate > 1.0:
raise ValueError(
f"GPU utilization should be less than 1.0, but is set to {self.gpu_memory_utilization}."
)
if self.tokenizer_mode not in ["auto", "slow"]:
raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")

46
colossalai/inference/core/engine.py

@ -1,12 +1,14 @@
from logging import Logger
from typing import Optional
from .request_handler import RequestHandler
from transformers import AutoConfig
from .config import InferenceConfig
class InferEngine:
class InferenceEngine:
"""
InferEngine is the core component for Inference.
InferenceEngine is the core component for Inference.
It is responsible for launch the inference process, including:
- Initialize model and distributed training environment(if needed)
@ -15,37 +17,27 @@ class InferEngine:
- Log the generation process
Args:
colossal_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
model_config : The configuration for the model.
parallel_config: The configuration for parallelize model.
cache_config : Configuration for initialize and manage kv cache.
tokenizer (Tokenizer): The tokenizer to be used for inference.
use_logger (bool): Determine whether or not to log the generation process.
tokenizer: Path of the tokenizer to use.
inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
verbose (bool): Determine whether or not to log the generation process.
"""
def __init__(
self,
model_config,
cache_config,
parallel_config,
tokenizer,
use_logger: bool = False,
colossal_config: Optional["ColossalInferConfig"] = None,
tokenizer: str = None,
inference_config: Optional["InferenceConfig"] = None,
verbose: bool = False,
) -> None:
assert colossal_config or (
model_config and cache_config and parallel_config
), "Please provide colossal_config or model_config, cache_config, parallel_config"
if colossal_config:
model_config, cache_config, parallel_config = colossal_config
self.model_config = model_config
self.cache_config = cache_config
self.parallel_config = parallel_config
self._verify_config()
assert inference_config, "Please provide inference_config."
self._init_model()
self.request_handler = RequestHandler(cache_config)
if use_logger:
# cache_config may need to be modified later.
# self.request_handler = RequestHandler(cache_config)
self.tokenizer = tokenizer
self.hf_model_config = AutoConfig.from_pretrained(
self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
)
if verbose:
self.logger = Logger()
def _init_model(self):

169
colossalai/inference/core/inference_struct.py

@ -0,0 +1,169 @@
import enum
from dataclasses import dataclass
from typing import Dict, List, Set
class RequsetStatus(enum.Enum):
"""The status of Sentences"""
WAITING = enum.auto()
RUNNING = enum.auto()
ABORTED = enum.auto()
OVERLENGTH = enum.auto()
COMPLETED = enum.auto()
LENGTH_CAPPED = enum.auto()
@staticmethod
def is_finished(status: "RequsetStatus") -> bool:
return status in [
RequsetStatus.OVERLENGTH,
RequsetStatus.COMPLETED,
RequsetStatus.LENGTH_CAPPED,
]
@staticmethod
def is_running(status: "RequsetStatus") -> bool:
return status == RequsetStatus.RUNNING
@staticmethod
def is_waiting(status: "RequsetStatus") -> bool:
return status == RequsetStatus.WAITING
class Sequence:
"""Store information of input sequence.
Args:
request_id: The ID of input sequence.
prompt: The prompt of input sequence.
token_id: The tokens ID of input sequence.
block_size: The block size of input sequence.
sample_params: The sample_params of input sequence.
block_table_index: The index of input sequence in block_table.
"""
def __init__(
self,
request_id: int,
prompt: str,
token_id: List[int],
block_size: int,
sample_params, # SampleParams needs to be imported later.
block_table_index: int,
):
self.request_id = request_id
self.prompt = prompt
self.input_token_id = token_id
self.blokc_size = block_size
self.sample_params = sample_params
self.output_token_id = []
self.status = RequsetStatus.WAITING
self.block_table_index = block_table_index
def get_sentence_len(self) -> None:
"""
Get length of current sentence.
"""
return len(self.input_token_id) + len(self.output_token_id)
def get_input_len(self) -> None:
"""
Get length of input sentence.
"""
return len(self.input_token_id)
def get_output_len(self) -> None:
"""
Get output length of current sentence.
"""
return len(self.output_token_id)
def check_finish(self) -> bool:
"""
Check whether inference is over.
"""
return RequsetStatus.is_finished(self.status)
def __repr__(self) -> str:
return (
f"Request ID(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
f"logical block number={len(self._logical_blocks)}"
)
@dataclass
class BatchHandler:
"""
Information to be passed and used for a batch of sequences.
"""
sequences_set: Set[Sequence]
block_table: Dict[int, int]
@classmethod
def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
"""
Initializes inference batches by input sentence list.
Args:
seqs (List[Sequence]): List of input sequence.
"""
sequences_set = set()
block_table = {}
for seq in seqs:
if seq in sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in block_table
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue
assert (
seq.request_id not in block_table
), "The sequence has not been added to sequences_set, but it is already in block_table."
sequences_set.add(seq)
block_table[seq.request_id] = seq.block_table_index
return cls(sequences_set=sequences_set, block_table=block_table)
def clear_batch(self) -> None:
"""
Clear sequence set and block table.
"""
for seq in self.sequences_set:
if not seq.check_finish():
seq.status = RequsetStatus.ABORTED
self.sequences_set.clear()
self.block_table.clear()
def fliter_batch(self) -> None:
"""
Remove completed sentences from a batch.
"""
for seq in self.sequences_set:
if seq.check_finish():
self.sequences_set.reomve(seq)
del self.block_table[seq.request_id]
def add_seqs(self, seqs: List[Sequence]) -> None:
"""
Add new sequence to batch
Args:
seqs (List[Sequence]): The list of new sequences.
"""
for seq in seqs:
if seq in self.sequences_set:
print("The sequence is already in sequences_set.")
assert (
seq.request_id in self.block_table
), "The sequence has been added to sequences_set, but it has not been added to block_table."
continue
assert (
seq.request_id not in self.block_table
), "The sequence has not been added to sequences_set, but it is already in block_table."
self.sequences_set.add(seq)
self.block_table[seq.request_id] = seq.block_table_index

37
tests/test_infer/test_config_and_struct.py

@ -0,0 +1,37 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence
def test_config_and_struct():
InferenceConfig("/llama")
sequence = Sequence(
request_id=1,
prompt="abc",
token_id=[1, 2, 3],
block_size=16,
sample_params=None,
block_table_index=1,
)
sequence2 = Sequence(
request_id=2,
prompt="bcd",
token_id=[4, 5, 6],
block_size=16,
sample_params=None,
block_table_index=2,
)
assert sequence.get_sentence_len() == 3
assert sequence.get_input_len() == 3
assert sequence.get_output_len() == 0
assert sequence.check_finish() == False
batch = BatchHandler.init_batch([sequence])
batch.fliter_batch()
batch.add_seqs([sequence2])
batch.clear_batch()
if __name__ == "__main__":
test_config_and_struct()
Loading…
Cancel
Save