2023-12-01 09:31:31 +00:00
|
|
|
from typing import List
|
|
|
|
|
2023-12-18 02:40:47 +00:00
|
|
|
from colossalai.inference.struct import BatchInfo, Sequence
|
|
|
|
|
2023-12-01 09:31:31 +00:00
|
|
|
|
2023-12-01 09:02:44 +00:00
|
|
|
class RequestHandler:
    """
    RequestHandler is the core for handling existing requests and updating the current batch.
    During the generation process, the schedule function is called each iteration to update
    the current batch.

    Args:
        inference_config: Store the configuration information related to inference.
        model_config: The huggingface model config.
    """

    def __init__(self, inference_config, model_config) -> None:
        self.inference_config = inference_config
        self.model_config = model_config
        self._init_cache()
        # Requests that have been submitted but not scheduled yet.
        self.waiting_list: List["Sequence"] = []
        # Requests handed to the batch by the most recent `schedule` call.
        self.running_list: List["Sequence"] = []
        self.batch = BatchInfo.init_batch()

    def _init_cache(self):
        """
        Initialize the cache manager with cache config.

        NOTE: currently a placeholder with no body; it performs no work and returns None.
        """

    def schedule(self):
        """
        The main logic of request handler.

        Returns:
            The handler's batch object (`self.batch`). When the waiting list is non-empty,
            its sequences are added to the batch first.
        """
        # The code below is only used for testing engine and will be modified.
        if self.waiting_list:
            # NOTE: `running_list` is bound to the very same list object as `waiting_list`
            # (no copy is made), so the two names alias each other until `update` rebinds them.
            self.running_list = self.waiting_list
            self.batch.add_seqs(self.running_list)
        return self.batch

    def add_sequence(self, req_seq: "Sequence"):
        """
        Add the request to waiting list.

        Args:
            req_seq: The sequence to append to the end of the waiting list.
        """
        self.waiting_list.append(req_seq)

    def abort_sequence(self, seq_id: str):
        """
        Abort the request.

        TODO: implement this. For now the method only invokes `_find_sequence` and
        discards the result; nothing is removed from any list.

        Args:
            seq_id: Identifier of the sequence to abort.
        """
        self._find_sequence(seq_id)
        return

    def _find_sequence(self, seq_id: str) -> "Sequence":
        """
        Find the request by seq_id.

        NOTE: not implemented yet; the body is empty, so the call returns None despite
        the annotated return type.

        Args:
            seq_id: Identifier of the sequence to look up.
        """

    def check_unfinished_seqs(self) -> bool:
        """
        Report whether any sequence is still waiting or running.

        Returns:
            True if the waiting list or the running list is non-empty, False otherwise.
        """
        # `bool(...)` keeps the return value a real bool rather than one of the lists.
        return bool(self.waiting_list or self.running_list)

    def update(self):
        """
        Update the waiting list and running list.

        Returns:
            A list built from `self.batch.sequences_set`, captured before the batch is cleared.
        """
        # The code below is only used for testing engine and will be modified.
        self.waiting_list = []
        self.running_list = []
        # Snapshot the batch contents first: `clear_batch` is called right after.
        finished_sequences = list(self.batch.sequences_set)

        self.batch.clear_batch()
        return finished_sequences
|