2023-12-07 06:34:01 +00:00
|
|
|
import enum
|
|
|
|
from dataclasses import dataclass
|
2023-12-26 13:34:27 +00:00
|
|
|
from typing import Any, List, Tuple, Union
|
2023-12-18 02:40:47 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
from ordered_set import OrderedSet
|
|
|
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
# Module-level logger obtained from ColossalAI's get_dist_logger, named after this module.
logger = get_dist_logger(__name__)
|
2023-12-07 06:34:01 +00:00
|
|
|
|
2023-12-12 09:22:41 +00:00
|
|
|
"""
|
|
|
|
The abstractions of request and sequence are defined here.
|
|
|
|
"""
|
|
|
|
|
2023-12-07 06:34:01 +00:00
|
|
|
|
2023-12-18 02:40:47 +00:00
|
|
|
class RequestStatus(enum.Enum):
    """
    Enumeration of the states a request (sequence) can be in.

    The first group of members describes in-flight/aborted requests, the
    second group describes the ways a request can reach completion.
    """

    # running status
    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()

    # completion status
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if `status` is one of the completion states.

        ABORTED is not counted as a completion state by this check.
        """
        completion_states = (
            RequestStatus.OVERLENGTH,
            RequestStatus.COMPLETED,
            RequestStatus.LENGTH_CAPPED,
        )
        return status in completion_states

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
        """Return True if `status` equals RUNNING."""
        return RequestStatus.RUNNING == status

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
        """Return True if `status` equals WAITING."""
        return RequestStatus.WAITING == status
|
2023-12-07 06:34:01 +00:00
|
|
|
|
|
|
|
|
2023-12-18 02:40:47 +00:00
|
|
|
@dataclass
class Sequence:
    """Store information of an input sequence and track its generation state.

    Args:
        request_id (int): The ID of input sequence.
        prompt (str): The prompt of input sequence.
        input_token_id (List[int]): The tokens ID of input sequence.
        block_size (int): The block size of input sequence.
        sample_params (SampleParams): The sample_params of input sequence.
        block_table (torch.Tensor): The index of input sequence in block_table.
        eos_token_id (int): The eos token id for this inference process.
        max_output_len (int): Maximum output length.

    Besides the dataclass fields, every instance carries two attributes created
    in ``__post_init__``: ``output_token_id`` (the generated token ids) and
    ``status`` (a ``RequestStatus``).
    """

    request_id: int
    prompt: str
    input_token_id: List[int]
    block_size: int
    sample_params: Any  # SampleParams needs to be imported later.
    block_table: torch.Tensor
    eos_token_id: int
    max_output_len: int = 256

    def __post_init__(self):
        # Not dataclass fields: they always start empty / WAITING.
        self.output_token_id = []
        self.status = RequestStatus.WAITING

    @property
    def sentence_len(self) -> int:
        """
        Get length of current sentence (input tokens plus generated tokens).
        """
        return len(self.input_token_id) + len(self.output_token_id)

    @property
    def input_len(self) -> int:
        """
        Get length of input sentence.
        """
        return len(self.input_token_id)

    @property
    def output_len(self) -> int:
        """
        Get length of output sentence.
        """
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        """
        Check whether the inference is finished.

        If the last generated token is the eos token, or the number of generated
        tokens has reached ``max_output_len``, the status is switched to
        COMPLETED as a side effect.

        Returns:
            bool: Whether the inference is finished.
        """
        if RequestStatus.is_finished(self.status):
            return True

        if self.output_token_id:
            hit_eos = self.output_token_id[-1] == self.eos_token_id
            if hit_eos or self.output_len >= self.max_output_len:
                self.status = RequestStatus.COMPLETED
                return True

        return False

    def __hash__(self):
        # Hash only on the request id so sequences can be stored in sets.
        return hash(self.request_id)

    def mark_running(self) -> None:
        """
        Set status for prefill reqs.

        Raises:
            AssertionError: If the sequence is not currently WAITING.
        """
        assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
        self.status = RequestStatus.RUNNING

    def mark_finished(self) -> None:
        """
        Set status for finished reqs.
        """
        self.status = RequestStatus.COMPLETED

    def mark_aborted(self) -> None:
        """
        Set status for aborted reqs.
        """
        self.status = RequestStatus.ABORTED

    def recycle(self) -> None:
        """
        Recycle a running sequence back to the waiting list.

        Raises:
            AssertionError: If the sequence is already finished or aborted.
        """
        # `RequestStatus.is_finished` is a static method and has to be *called* with
        # the status. Referencing it as `self.status.is_finished` yields the function
        # object, which is always truthy and made this assertion fail unconditionally.
        assert (
            not RequestStatus.is_finished(self.status) and self.status != RequestStatus.ABORTED
        ), "The running sequence is already done but it still in running list"
        self.status = RequestStatus.WAITING

    def __repr__(self) -> str:
        # Fields are separated by ", " and wrapped in exactly one pair of parentheses.
        return (
            f"(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"logical_block_number={self.block_table.shape[0]}, "
            f"input_len={self.input_len}, "
            f"output_len={self.output_len})"
        )
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BatchInfo:
    """
    Information to be passed and used for a batch of sequences.

    Attributes:
        sequences_set: Ordered set of the sequences in this batch. Created empty
            in ``__post_init__`` when not supplied.
        is_prompts: When True the batch inputs are the full prompt token ids;
            otherwise only the last generated token of each sequence is used.
        device: Device on which the tensors built by this class are created.
            Defaults to ``torch.cuda.current_device()`` when not supplied.
    """

    sequences_set: OrderedSet["Sequence"] = None
    is_prompts: bool = True
    device: torch.device = None

    def __post_init__(self):
        if self.device is None:
            self.device = torch.cuda.current_device()
        if self.sequences_set is None:
            self.sequences_set = OrderedSet()

    def init_batch(self, seqs: List["Sequence"] = None):
        """
        Initializes inference batches by input sentence list.

        Args:
            seqs (List["Sequence"]): List of input sequence. A single sequence is
                also accepted; ``None`` leaves the batch empty.

        Raises:
            AssertionError: If the batch already contains sequences.
        """
        assert len(self.sequences_set) == 0, "Sequences set has been initialized."

        if seqs is not None:
            # Same insertion rules (wrap single item, warn on and skip duplicates)
            # as add_seqs, so reuse it instead of repeating the loop.
            self.add_seqs(seqs)

    def get_block_table_tensor(self) -> torch.Tensor:
        """
        Stack the block tables of all sequences in the batch into one tensor.

        Returns:
            torch.Tensor: Result of ``torch.stack`` over every sequence's block_table.

        Raises:
            AssertionError: If a sequence has no block_table or the batch is empty.
        """
        tensor_list = []
        for seq in self.sequences_set:
            assert (
                seq.block_table is not None
            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
            tensor_list.append(seq.block_table)

        assert tensor_list, "Batch has not been initialized yet. Please initialize batch first."
        return torch.stack(tensor_list)

    def clear_batch(self) -> None:
        """
        Clear sequence set; sequences that are not finished are marked ABORTED first.
        """
        for seq in self.sequences_set:
            if not seq.check_finish():
                seq.status = RequestStatus.ABORTED
        self.sequences_set.clear()

    def fliter_batch(self) -> List["Sequence"]:
        """
        Remove completed sentences from a batch.

        Returns:
            List["Sequence"]: List of finished sequences.
        """
        # Collect first, then remove, so the set is not mutated while being iterated.
        finish_seqs = [seq for seq in self.sequences_set if seq.check_finish()]
        for finish_seq in finish_seqs:
            self.sequences_set.discard(finish_seq)
        return finish_seqs

    def abort_seq(self, seq: "Sequence") -> "Sequence":
        """
        Remove sequence from the batch, marking it ABORTED if it is not finished.

        Returns:
            Sequence: The sequence that was passed in.
        """
        if not seq.check_finish():
            seq.status = RequestStatus.ABORTED
        self.sequences_set.discard(seq)
        return seq

    def add_seqs(self, seqs: List["Sequence"]) -> None:
        """
        Add new sequence to batch

        Args:
            seqs (List["Sequence"]): The list of new sequences. A single sequence
                is also accepted and is wrapped into a list.
        """
        if not isinstance(seqs, list):
            seqs = [seqs]

        for seq in seqs:
            if self.sequences_set and seq in self.sequences_set:
                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                continue
            self.sequences_set.add(seq)

    @property
    def is_empty(self) -> bool:
        """
        Check whether sequences_set is empty.
        """
        return not self.sequences_set

    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
        """
        Add an output token for each sentence in the batch.

        Args:
            tokens (List[int]): A batch of tokens. Each entry is either a single
                int or a list of ints; a tensor is converted with ``tolist()``.

        Raises:
            AssertionError: If the number of entries differs from the batch size.
            TypeError: If an entry is neither a list nor an int.
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()

        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."

        for seq, token in zip(self.sequences_set, tokens):
            if not isinstance(token, list):
                if not isinstance(token, int):
                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
                token = [token]
            seq.output_token_id += token
            # Updates seq.status to COMPLETED when eos / max length is reached.
            seq.check_finish()

    def get_batch_size(self) -> int:
        """
        Get batch_size of this batch
        """
        return len(self.sequences_set)

    def get_batch_inputs(self) -> torch.LongTensor:
        """
        Get batch inputs for forward inference computation.

        Returns:
            torch.LongTensor: One row per sequence; the prompt token ids when
            ``is_prompts`` is True, otherwise the last generated token.
        """
        if self.is_prompts:
            input_list = [seq.input_token_id for seq in self.sequences_set]
        else:
            input_list = [[seq.output_token_id[-1]] for seq in self.sequences_set]

        return torch.tensor(input_list, dtype=torch.long, device=self.device)

    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Flattening the input tokens.

        Returns:
            Tuple[torch.LongTensor, torch.Tensor]: The flattened token ids of the
            whole batch and, per sequence, the number of tokens it contributed.
        """
        input_list = []
        input_len_list = []
        for seq in self.sequences_set:
            if self.is_prompts:
                input_list.extend(seq.input_token_id)
                input_len_list.append(seq.sentence_len)
            else:
                input_list.append(seq.output_token_id[-1])
                input_len_list.append(1)

        flat_inputs = torch.tensor(input_list, dtype=torch.long, device=self.device)
        input_lens = torch.tensor(input_len_list, dtype=torch.int, device=self.device)
        return flat_inputs, input_lens

    def get_sequence_lengths(self) -> torch.Tensor:
        """
        Get the sentence_len (input plus generated tokens) of each sentence in this batch.
        """
        len_list = [seq.sentence_len for seq in self.sequences_set]
        return torch.tensor(len_list, dtype=torch.int, device=self.device)

    def get_attn_mask(self, padding_id: int) -> Union[torch.Tensor, None]:
        """
        Generate and return attention mask.

        Args:
            padding_id (int): Token id treated as padding (masked out with 0).

        Returns:
            The mask tensor (1 for non-padding tokens, 0 for padding) if at least
            one position is padding, otherwise None.
        """
        past_values = [seq.input_token_id + seq.output_token_id for seq in self.sequences_set]

        attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()

        if torch.any(attn_mask == 0):
            return attn_mask
        # No padding token anywhere in the batch: callers receive None instead of a mask.
        return None

    def __repr__(self) -> str:
        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|