mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
5.8 KiB
152 lines
5.8 KiB
from typing import List
|
|
|
|
from .dynamic_batching.io_struct import Batch, Req, RequestOutput
|
|
from .manager import DynamicBatchManager
|
|
from .tensor_parallel import TPInferEngine
|
|
|
|
|
|
class Async_DynamicBatchManager(DynamicBatchManager):
    """Dynamic batch manager whose ``_step`` returns finished request outputs.

    Each call to ``_step`` prefills or decodes one batch. It returns a list of
    ``RequestOutput`` objects for the requests that finished during that step,
    or ``None`` when no request finished.
    """

    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: List = None,
    ):
        """
        Args:
            tp_engine: the tp engine held by this manager; it must be constructed
                before the manager.
            max_total_token_num: max_total_token_num for the memory manager,
                typically max batch size * (max input len + max output len).
            batch_max_tokens: max tokens of one batch, typically
                (max input + output len) * num_requests.
            model: the model weight dir path; config, weights and tokenizer are
                loaded from this dir.
            tokenizer: tokenizer passed through to the base class. ``_output_process``
                decodes with ``self.tokenizer`` (presumably set by the base class
                from this argument or from ``model`` -- confirm in DynamicBatchManager).
            eos_id: the end token of a sequence.
            log_stats: whether to log stats.
            log_stats_interval: log stats interval.
            running_batch: the initial running batch, if any.
            waiting_req_list: list of waiting requests, typically created before the
                manager and shared with the request producer. Defaults to a new
                empty list per instance.
        """
        if waiting_req_list is None:
            # A literal `[]` default is evaluated once at definition time and would
            # be shared by every manager constructed without an explicit list.
            waiting_req_list = []
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list,
        )

    def _step(self):
        """
        Logic for handling requests: run one scheduling step.

        - No running batch: build a new batch from the request queue and prefill it.
        - Running batch that has waited fewer than ``max_wait_tokens`` steps:
          decode it.
        - Otherwise: try to build a mini batch, prefill it and merge the surviving
          requests into the running batch. If no mini batch is available, decode
          the running batch instead.

        Returns:
            The outputs of requests that finished during this step, or ``None``
            if no request finished.
        """
        has_new_finished = False
        outputs = None

        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0
        elif self.has_wait_tokens < self.max_wait_tokens:
            has_new_finished, outputs = self._decode_running_batch()
        else:
            new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_mini_batch is not None:
                self.stats_tool.count_prompt_tokens(new_mini_batch)
                has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                # If every request in the mini batch finished during prefill, the
                # batch was already removed in _handle_finish_req. Only merge it
                # while it still holds requests.
                if not new_mini_batch.is_clear():
                    self._merge_batch(self.running_batch, new_mini_batch)
                    self.running_batch.merge(new_mini_batch)
                self.has_wait_tokens = 0
            else:
                has_new_finished, outputs = self._decode_running_batch()

        return outputs if has_new_finished else None

    def _decode_running_batch(self):
        """Run one decode pass over the running batch and bump the wait counter.

        Shared by both decode paths of ``_step``.

        Returns:
            ``(has_new_finished_req, outputs)`` as returned by ``_decode_batch``.
        """
        self.stats_tool.count_output_tokens(self.running_batch)
        has_new_finished, outputs = self._decode_batch(self.running_batch)
        self._filter_runing_batch()
        self.has_wait_tokens += 1
        return has_new_finished, outputs

    def _prefill_batch(self, batch):
        """
        Prefill ``batch``. Every batch must be prefilled first, whether it is a new
        batch or a mini batch that will be merged into the running one.

        Returns:
            ``(has_new_finished_req, outputs)``. ``outputs`` is ``None`` unless some
            request finished.
        """
        self._init_batch(batch)

        # TODO: figure out if cache and batch id is needed
        req_to_out_token_id = self.engine._prefill_batch(batch.batch_id)
        return self._apply_new_tokens(batch, req_to_out_token_id)

    def _decode_batch(self, batch: Batch):
        """
        Decoding process: run the engine's decode step for ``batch``.

        Returns:
            ``(has_new_finished_req, outputs)``. ``outputs`` is ``None`` unless some
            request finished.
        """
        req_to_out_token_id = self.engine._decode_batch(batch.batch_id)
        return self._apply_new_tokens(batch, req_to_out_token_id)

    def _apply_new_tokens(self, batch: Batch, req_to_out_token_id):
        """Record new token ids on ``batch``, then mark and handle finished requests.

        Post-processing shared by prefill and decode.

        Returns:
            ``(has_new_finished_req, outputs)``.
        """
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        """Delete finished requests from ``batch`` and return their outputs.

        When ``batch`` has no requests left afterwards, it is removed via
        ``_remove_batch``; otherwise it is updated via ``_filter_batch``.

        Returns:
            A list of ``RequestOutput`` for the finished requests, or ``None`` when
            ``has_new_finished_req`` is falsy.
        """
        if not has_new_finished_req:
            return None

        # delete finished reqs -- presumably filter_finished() removes them from
        # the batch and returns them (confirm in io_struct.Batch).
        finished_reqs = batch.filter_finished()
        if batch.is_clear():
            self._remove_batch(batch)
        else:
            self._filter_batch(batch)
        return self._output_process(finished_reqs)

    def _output_process(self, finished_reqs: List[Req]):
        """
        Process the output of a batch: decode each finished request's output ids
        with ``self.tokenizer`` and wrap the text in a ``RequestOutput``.
        """
        return [
            RequestOutput(req.request_id, req.prompts, req.prompt_ids, self.tokenizer.decode(req.output_ids))
            for req in finished_reqs
        ]
|
|
|
|
|
|
def start_dynamic_batching(args, tp_engine, waiting_req_list):
    """
    Build an ``Async_DynamicBatchManager`` from ``args``.

    Args:
        args: an object (presumably an argparse namespace) providing
            ``max_total_token_num``, ``batch_max_tokens``, ``eos_id``, ``model``,
            ``disable_log_stats`` and ``log_stats_interval``.
        tp_engine: the tp engine the manager will hold.
        waiting_req_list: list of waiting requests handed to the manager.

    Returns:
        The constructed ``Async_DynamicBatchManager``.

    Raises:
        Whatever exception the manager's construction raises, with its original
        type, message and traceback.
    """
    # No try/except here: re-raising as a bare `Exception` would throw away the
    # real error's type and message, which is what callers need to diagnose it.
    return Async_DynamicBatchManager(
        tp_engine=tp_engine,
        max_total_token_num=args.max_total_token_num,
        batch_max_tokens=args.batch_max_tokens,
        eos_id=args.eos_id,
        model=args.model,
        log_stats=not args.disable_log_stats,
        log_stats_interval=args.log_stats_interval,
        waiting_req_list=waiting_req_list,
    )
|