mirror of https://github.com/hpcaitech/ColossalAI
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer
* update codes
* add ci test for config and struct

pull/5258/head
yuehuayingxueluo
12 months ago
committed by
FrankLeeeee
5 changed files with 279 additions and 34 deletions
@@ -1,7 +0,0 @@
"""
Our config consists of three parts:
    1. model_config: The configuration for the model, including `model name`, `model path`, and self-defined layers.
    2. parallel_config: The configuration for parallelizing the model, including `tp_size`, `pp_size`, `world size`, `local rank`, `master port`, and `master ip`.
    3. cache_config: The configuration for initializing and managing the KV cache, including `block size` and `block num`.
For the convenience of users, we provide a unified config API that wraps all the configs. One can easily construct a colossal_config by setting the needed configs.
"""
@@ -0,0 +1,54 @@
from typing import Optional, Union
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class InferenceConfig:
    """The inference configuration.

    Args:
        model: Path or nn.Module of this model.
        tokenizer: Path of the tokenizer to use.
        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Whether to trust remote code from huggingface.
        max_batch_size: Maximum batch size.
        max_output_len: Maximum output length.
        max_input_len: Maximum input length.
        block_size: The number of tokens in one logical block.
        gpu_utilization_rate: Maximum GPU memory usage ratio.
        dtype: The data type for weights and activations.
        tp_size: Tensor parallel size.
        pp_size: Pipeline parallel size.
        max_seq_len: Maximum length of the input sequence.
        quant_mode: Quantization mode.
        revision: The specific model version to use (a branch name, tag name, or commit id).
    """

    model: Union[str, nn.Module]
    tokenizer: Optional[str] = None
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    max_batch_size: int = 8
    max_output_len: int = 256
    max_input_len: int = 256
    block_size: int = 16
    gpu_utilization_rate: float = 0.7
    dtype: Union[str, torch.dtype] = torch.float32
    tp_size: int = 1
    pp_size: int = 1
    max_seq_len: Optional[int] = None
    quant_mode: Optional[str] = None
    revision: Optional[str] = None

    def __post_init__(self):
        self._verify_args()

    def _verify_args(self):
        if self.gpu_utilization_rate > 1.0:
            raise ValueError(
                f"GPU utilization rate should be no greater than 1.0, but is set to {self.gpu_utilization_rate}."
            )
        if self.tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(f"Tokenizer mode must be either 'auto' or 'slow', but got {self.tokenizer_mode}.")
@@ -0,0 +1,169 @@
import enum
from dataclasses import dataclass
from typing import Dict, List, Set


class RequestStatus(enum.Enum):
    """The status of a sequence."""

    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status in [
            RequestStatus.OVERLENGTH,
            RequestStatus.COMPLETED,
            RequestStatus.LENGTH_CAPPED,
        ]

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
        return status == RequestStatus.RUNNING

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
        return status == RequestStatus.WAITING

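A quick sketch of how these status helpers are queried (pure illustration, no new API assumed):

    status = RequestStatus.WAITING
    assert RequestStatus.is_waiting(status)

    # OVERLENGTH, COMPLETED, and LENGTH_CAPPED all count as finished.
    assert RequestStatus.is_finished(RequestStatus.COMPLETED)
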
class Sequence:
    """Store the information of an input sequence.

    Args:
        request_id: The ID of the input sequence.
        prompt: The prompt of the input sequence.
        token_id: The token IDs of the input sequence.
        block_size: The block size of the input sequence.
        sample_params: The sample_params of the input sequence.
        block_table_index: The index of the input sequence in the block_table.
    """

    def __init__(
        self,
        request_id: int,
        prompt: str,
        token_id: List[int],
        block_size: int,
        sample_params,  # SampleParams needs to be imported later.
        block_table_index: int,
    ):
        self.request_id = request_id
        self.prompt = prompt
        self.input_token_id = token_id
        self.block_size = block_size
        self.sample_params = sample_params
        self.output_token_id: List[int] = []
        self.status = RequestStatus.WAITING
        self.block_table_index = block_table_index

    def get_sentence_len(self) -> int:
        """
        Get the total length of the current sentence (prompt plus generated tokens).
        """
        return len(self.input_token_id) + len(self.output_token_id)

    def get_input_len(self) -> int:
        """
        Get the length of the input sentence.
        """
        return len(self.input_token_id)

    def get_output_len(self) -> int:
        """
        Get the output length of the current sentence.
        """
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        """
        Check whether inference is over.
        """
        return RequestStatus.is_finished(self.status)

    def __repr__(self) -> str:
        return (
            f"Request ID(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"block_table_index={self.block_table_index})"
        )

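A short sketch of the intended lifecycle of a Sequence. Appending generated tokens to output_token_id directly is an assumption about how the engine will drive this struct, since the decode loop is not part of this diff:

    seq = Sequence(
        request_id=0,
        prompt="hello",
        token_id=[1, 2],
        block_size=16,
        sample_params=None,
        block_table_index=0,
    )

    seq.status = RequestStatus.RUNNING
    seq.output_token_id.append(42)      # assumed: the engine appends each generated token
    assert seq.get_sentence_len() == 3  # 2 input tokens + 1 output token
    seq.status = RequestStatus.COMPLETED
    assert seq.check_finish()
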
@dataclass
class BatchHandler:
    """
    Information to be passed around and used for a batch of sequences.
    """

    sequences_set: Set[Sequence]
    block_table: Dict[int, int]

    @classmethod
    def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
        """
        Initialize an inference batch from a list of input sequences.

        Args:
            seqs (List[Sequence]): List of input sequences.
        """
        sequences_set = set()
        block_table = {}
        for seq in seqs:
            if seq in sequences_set:
                print("The sequence is already in sequences_set.")
                assert (
                    seq.request_id in block_table
                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
                continue
            assert (
                seq.request_id not in block_table
            ), "The sequence has not been added to sequences_set, but it is already in block_table."

            sequences_set.add(seq)
            block_table[seq.request_id] = seq.block_table_index

        return cls(sequences_set=sequences_set, block_table=block_table)

    def clear_batch(self) -> None:
        """
        Clear the sequence set and block table.
        """
        for seq in self.sequences_set:
            if not seq.check_finish():
                seq.status = RequestStatus.ABORTED
        self.sequences_set.clear()
        self.block_table.clear()

    def filter_batch(self) -> None:
        """
        Remove completed sentences from the batch.
        """
        # Iterate over a copy: removing from the set while iterating it raises a RuntimeError.
        for seq in list(self.sequences_set):
            if seq.check_finish():
                self.sequences_set.remove(seq)
                del self.block_table[seq.request_id]

    def add_seqs(self, seqs: List[Sequence]) -> None:
        """
        Add new sequences to the batch.

        Args:
            seqs (List[Sequence]): The list of new sequences.
        """
        for seq in seqs:
            if seq in self.sequences_set:
                print("The sequence is already in sequences_set.")
                assert (
                    seq.request_id in self.block_table
                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
                continue
            assert (
                seq.request_id not in self.block_table
            ), "The sequence has not been added to sequences_set, but it is already in block_table."
            self.sequences_set.add(seq)
            self.block_table[seq.request_id] = seq.block_table_index
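
Tying the two structs together, a hedged sketch of the batch lifecycle (setting status to COMPLETED by hand stands in for the engine's decode loop, which is outside this diff):

    batch = BatchHandler.init_batch([seq])  # seq from the Sequence sketch above
    assert seq.request_id in batch.block_table

    seq.status = RequestStatus.COMPLETED    # assumed: normally set by the engine
    batch.filter_batch()                    # finished sequences leave both structures
    assert seq not in batch.sequences_set
    assert seq.request_id not in batch.block_table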
@@ -0,0 +1,37 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence


def test_config_and_struct():
    InferenceConfig("/llama")
    sequence = Sequence(
        request_id=1,
        prompt="abc",
        token_id=[1, 2, 3],
        block_size=16,
        sample_params=None,
        block_table_index=1,
    )

    sequence2 = Sequence(
        request_id=2,
        prompt="bcd",
        token_id=[4, 5, 6],
        block_size=16,
        sample_params=None,
        block_table_index=2,
    )

    assert sequence.get_sentence_len() == 3
    assert sequence.get_input_len() == 3
    assert sequence.get_output_len() == 0
    assert not sequence.check_finish()

    batch = BatchHandler.init_batch([sequence])
    batch.filter_batch()
    batch.add_seqs([sequence2])
    batch.clear_batch()


if __name__ == "__main__":
    test_config_and_struct()