from typing import List, Optional

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: Optional[List] = None,
    ):
        """
        Args:
            tp_engine : the TP engine that the dynamic batch manager holds, defined before the dynamic batch manager
            max_total_token_num : max total token number for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
            running_max_req_size : max request size of the running batch, equal to MAX_BATCH_SIZE of the TP engine
            eos_id : the end-of-sequence token id of a seq
            model : the model weight dir path; the app will load config, weights, and tokenizer from this dir
            log_stats : whether to log stats
            log_stats_interval : log stats interval
            running_batch : the running batch
            waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
        """
        # Avoid a shared mutable default argument: resolve None to a fresh list per instance.
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list if waiting_req_list is not None else [],
        )

    def _step(self):
        """
        Logic for one scheduling step of request handling.
        """
        has_new_finished = False
        if self.running_batch is None:
            # No batch is running: try to build a new batch from the waiting queue and prefill it.
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_running_batch()
                self.has_wait_tokens = 0
        else:
            if self.has_wait_tokens < self.max_wait_tokens:
                # Keep decoding the running batch until it has waited long enough
                # to consider admitting new requests.
                self.stats_tool.count_output_tokens(self.running_batch)
                has_new_finished, outputs = self._decode_batch(self.running_batch)
                self._filter_running_batch()
                self.has_wait_tokens += 1
            else:
                # Try to admit a mini-batch of new requests: prefill it, then merge any
                # unfinished requests into the running batch.
                new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
                if new_mini_batch is not None:
                    self.stats_tool.count_prompt_tokens(new_mini_batch)
                    has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                    if not new_mini_batch.is_clear():
                        self._merge_batch(self.running_batch, new_mini_batch)
                        self.running_batch.merge(new_mini_batch)
                    self.has_wait_tokens = 0
                else:
                    # Nothing to admit: fall back to decoding the running batch.
                    self.stats_tool.count_output_tokens(self.running_batch)
                    has_new_finished, outputs = self._decode_batch(self.running_batch)
                    self._filter_running_batch()
                    self.has_wait_tokens += 1

        if has_new_finished:
            return outputs
        return None

    def _prefill_batch(self, batch):
        """
        For all batches, whether a new batch or a mini-batch, we need to do prefill first.
""" self._init_batch(batch) # TODO: figure out if cache and batch id is needed ans = self.engine._prefill_batch(batch.batch_id) req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) outputs = self._handle_finish_req(batch, has_new_finished_req) return has_new_finished_req, outputs # delete finished reqs def _decode_batch(self, batch: Batch): """ Decoding process """ ans = self.engine._decode_batch(batch.batch_id) req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) outputs = self._handle_finish_req(batch, has_new_finished_req) return has_new_finished_req, outputs def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) return self._output_process(finished_reqs) return None def _output_process(self, finished_reqs: List[Req]): """ Process the output of a batch. """ outputs = [] for req in finished_reqs: output = self.tokenizer.decode(req.output_ids) outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) return outputs def start_dynamic_batching(args, tp_engine, waiting_req_list): try: batch_manager = Async_DynamicBatchManager( tp_engine=tp_engine, max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, eos_id=args.eos_id, model=args.model, log_stats=not args.disable_log_stats, log_stats_interval=args.log_stats_interval, waiting_req_list=waiting_req_list, ) except Exception: raise Exception return batch_manager