# Adapted from https://github.com/ModelTC/lightllm
import uuid
from typing import List

import numpy as np

from .io_struct import Batch, Req


class ReqQueue:
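    """Waiting queue that schedules inference requests into batches.

    Requests are appended to ``waiting_req_list`` and later assembled into a
    new ``Batch`` subject to three limits: the total token budget shared by
    all running requests (``max_total_tokens``), the prompt-token budget of a
    single new batch (``batch_max_tokens``), and the maximum number of
    concurrently running requests (``running_max_req_size``).
    """
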
    def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=None) -> None:
        self.max_total_tokens = max_total_tokens
        assert batch_max_tokens is not None
        self.batch_max_tokens = batch_max_tokens
        self.running_max_req_size = running_max_req_size
        # Default to None instead of a mutable default argument, so instances do not share one list.
        self.waiting_req_list: List[Req] = waiting_req_list if waiting_req_list is not None else []

    def append(self, req):
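        """Add a request to the tail of the waiting queue."""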
        self.waiting_req_list.append(req)
        return

    def _init_cache_list(self, current_batch: Batch):
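        """Rebuild ``cache_len_list`` from the currently running batch.

        Each entry approximates ``(tokens already occupied, output tokens
        still to be generated)`` for one running request; the list is empty
        when there is no running batch.
        """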
        if current_batch is not None:
            self.cache_len_list = [
                (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
                for req in current_batch.reqs
            ]
        else:
            self.cache_len_list = []

    # @calculate_time(show=True, min_cost_ms=0.1)
    def _can_add_new_req(self, req):
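        """Check whether ``req`` fits without exhausting the token budget.

        Tentatively adds ``req`` to ``cache_len_list``, then estimates the
        peak number of tokens the cache would need if every request ran to
        its maximum output length, with requests finishing in order of
        remaining output length and freeing their tokens as they complete.
        The request is admitted only if that worst-case peak stays within
        ``max_total_tokens`` and the request count stays within
        ``running_max_req_size``.
        """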
        self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))  # hard to analyze exactly
        self.cache_len_list.sort(key=lambda x: -x[1])

        left_out_len_array = np.array([e[1] for e in self.cache_len_list])
        # assert left_out_len_array.min() >= 0
        has_run_len_array = np.array([e[0] for e in self.cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(self.cache_len_list) + 1, 1)

        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
        # NOTE: changed from `<` to `<=` here
        return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size

    def generate_new_batch(self, current_batch: Batch = None):
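        """Try to schedule a new ``Batch`` from the front of the waiting queue.

        Returns ``None`` early if the running batch is already at
        ``running_max_req_size``. Otherwise requests are admitted in FIFO
        order until one fails ``_can_add_new_req`` or would push the batch's
        prompt tokens past ``batch_max_tokens``; aborted requests encountered
        along the way are skipped. If at least one request was admitted, the
        scheduled (and skipped aborted) requests are removed from the waiting
        list and the new ``Batch`` is returned; otherwise ``None``.
        """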
        if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
            return None
        self._init_cache_list(current_batch)
        can_run_list = []
        new_batch_total_tokens = 0
        aborted_count = 0
        for req in self.waiting_req_list:
            flag = self._can_add_new_req(req)
            if req.aborted:
                aborted_count += 1
                continue
            if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
                can_run_list.append(req)
                new_batch_total_tokens += req.input_len
            else:
                break

        if len(can_run_list) != 0:
            new_batch = Batch(uuid.uuid4().hex, can_run_list)
            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
            return new_batch
        else:
            return None

    def __len__(self):
        return len(self.waiting_req_list)
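

# A minimal usage sketch (not part of the original module). It assumes `Req`
# exposes `input_len`, `max_output_len`, `output_ids`, and `aborted`, as used
# above; how a `Req` is constructed is elided here and left to `.io_struct`.
#
#     queue = ReqQueue(max_total_tokens=4096, batch_max_tokens=2048, running_max_req_size=8)
#     queue.append(req)                      # req: a Req built from an incoming prompt
#     batch = queue.generate_new_batch()     # returns a Batch or None
#     if batch is not None:
#         ...                                # hand the batch to the inference engine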