ColossalAI/colossalai/legacy/inference/dynamic_batching/req_queue.py

74 lines
2.7 KiB
Python
Raw Normal View History

[Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. 
* Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
2023-10-30 02:52:19 +00:00
# Adapted from https://github.com/ModelTC/lightllm
import uuid
from typing import List
import numpy as np
from .io_struct import Batch, Req
class ReqQueue:
    """FIFO queue of waiting requests that admits them into new batches.

    Admission of a waiting request is governed by three limits:

    * ``max_total_tokens`` -- upper bound on the estimated peak number of tokens
      held by the running batch plus the newly admitted requests (see
      ``_can_add_new_req``). NOTE(review): presumably the KV-cache token capacity;
      confirm with the caller.
    * ``batch_max_tokens`` -- upper bound on the summed ``input_len`` of the
      requests placed into a single new batch.
    * ``running_max_req_size`` -- upper bound on the number of requests in the
      running batch plus the newly admitted ones.

    Requests are read through ``input_len``, ``max_output_len``, ``output_ids``
    and ``aborted``; a running batch is read through ``reqs``.
    """

    def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=None) -> None:
        """
        Args:
            max_total_tokens: estimated peak-token budget (see class docstring).
            batch_max_tokens: prompt-token budget for one new batch; must not be None.
            running_max_req_size: maximum number of requests scheduled together.
            waiting_req_list: optional initial list of waiting requests, stored as-is
                (not copied). Defaults to a fresh empty list for this instance.
        """
        self.max_total_tokens = max_total_tokens
        assert batch_max_tokens is not None
        self.batch_max_tokens = batch_max_tokens
        self.running_max_req_size = running_max_req_size
        # A fresh list per instance: a ``[]`` default would be shared by every
        # queue constructed without an explicit list.
        self.waiting_req_list: List[Req] = [] if waiting_req_list is None else waiting_req_list
        # (tokens_held_now, tokens_still_to_generate) per scheduled request;
        # rebuilt by _init_cache_list at the start of every scheduling pass.
        self.cache_len_list = []

    def append(self, req):
        """Add ``req`` to the tail of the waiting queue."""
        self.waiting_req_list.append(req)

    def _init_cache_list(self, current_batch: "Batch"):
        """Reset ``cache_len_list`` from the requests already in ``current_batch``.

        Each entry is ``(input_len + generated, max_output_len - generated - 1)``,
        i.e. the tokens a running request holds now and the output budget it has
        left (minus one -- NOTE(review): presumably the token being decoded this
        step; confirm). An empty list is used when ``current_batch`` is None.
        """
        if current_batch is not None:
            self.cache_len_list = [
                (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
                for req in current_batch.reqs
            ]
        else:
            self.cache_len_list = []

    def _can_add_new_req(self, req) -> bool:
        """Tentatively add ``req`` to ``cache_len_list`` and check both limits.

        Peak-usage estimate: sort entries by remaining length, descending, giving
        ``left[i]`` and ``used[i]``. When entry ``i`` (0-based) finishes, entries
        ``0..i`` are still alive and have each grown by ``left[i]`` tokens, while
        entries after ``i`` (shorter remaining length) are done -- assuming finished
        requests release their tokens. The estimate is the worst case over ``i``:
        ``max_i(left[i] * (i + 1) + sum(used[0..i]))``.

        ``req`` stays in ``cache_len_list`` even when False is returned; the caller
        stops scanning at the first request that does not fit.

        Returns:
            True if the estimated peak is <= ``max_total_tokens`` and the request
            count is <= ``running_max_req_size``.
        """
        # The new request is counted as prompt + 1 token with one fewer token left
        # to generate. NOTE(review): presumably the token produced by prefill; confirm.
        self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))
        self.cache_len_list.sort(key=lambda x: -x[1])

        left_out_len_array = np.array([e[1] for e in self.cache_len_list])
        has_run_len_array = np.array([e[0] for e in self.cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(self.cache_len_list) + 1, 1)

        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
        return bool(
            need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size
        )

    def generate_new_batch(self, current_batch: "Batch" = None):
        """Take the longest admissible prefix of the waiting queue as a new Batch.

        Requests are scanned in FIFO order. Aborted requests are skipped without
        counting toward any limit and are removed from the queue. Scanning stops at
        the first request that would exceed a limit.

        Args:
            current_batch: the running batch, or None when nothing is running.

        Returns:
            A new ``Batch`` holding the admitted requests, or None when no request
            could be admitted (including when ``current_batch`` is already full).
        """
        if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
            return None

        self._init_cache_list(current_batch)
        can_run_list = []
        new_batch_total_tokens = 0
        aborted_count = 0
        for req in self.waiting_req_list:
            # Check abort first: an aborted request must never be added to
            # cache_len_list, or it would inflate the token and request-count
            # estimate for every request scanned after it.
            if req.aborted:
                aborted_count += 1
                continue
            if self._can_add_new_req(req) and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
                can_run_list.append(req)
                new_batch_total_tokens += req.input_len
            else:
                break

        new_batch = Batch(uuid.uuid4().hex, can_run_list) if can_run_list else None

        # Every request before the break point was either admitted or aborted, so
        # the consumed requests form a prefix of the queue. Drop that prefix even
        # when nothing was admitted, so aborted requests do not linger forever.
        consumed = len(can_run_list) + aborted_count
        if consumed:
            self.waiting_req_list = self.waiting_req_list[consumed:]
        return new_batch

    def __len__(self):
        """Return the number of waiting requests (including not-yet-purged aborted ones)."""
        return len(self.waiting_req_list)