@@ -1,5 +1,6 @@
 import time
 from typing import List
+import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -8,6 +9,8 @@ from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
+from transformers import AutoTokenizer
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 class DynamicBatchManager:
     def __init__(
@@ -54,6 +57,20 @@ class DynamicBatchManager:
         self.req_queue.append(req)
         return
 
+    def add_input(self, request_id, sampling_params, input_ids):
+        """
+        Encode and Add new input to req queue. support one sequence input for now.
+        """
+        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_len = len(prompt_ids)
+        if prompt_len > self.engine.max_input_len:
+            raise ValueError(
+                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
+            )
+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+        self.add_req(prompt_ids, sampling_params, request_id)
+        return
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
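A minimal caller-side sketch of the enqueue path added above, assuming an already constructed DynamicBatchManager (here named `manager`) whose tokenizer has been set; `submit_prompt` is a hypothetical helper, and SamplingParams is the class imported at the top of this module:

# Sketch only, not part of the patch: enqueue one prompt on an existing manager.
# add_input() tokenizes the prompt, rejects it if it is longer than
# engine.max_input_len, resolves stop-sentence token ids, and appends a Req
# to the waiting queue that the batching loop picks up.
def submit_prompt(manager, request_id: int, prompt: str) -> None:
    manager.add_input(request_id, SamplingParams(), prompt)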
@@ -66,13 +83,15 @@ class DynamicBatchManager:
                 req.aborted = True
         return
 
-    def loop_for_fwd(self):
+    async def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        while self.running_batch is not None or self.req_queue.waiting_req_list:
-            self._step()
+        # self.running_batch is not None or self.req_queue.waiting_req_list
+        while True:
+            async for item in self._step():
+                yield item
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
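The rewritten loop turns loop_for_fwd into a generator that simply re-yields whatever _step produces, so finished outputs bubble up from the lowest layer of the call chain to the caller instead of being swallowed. A standalone illustration of that bubbling pattern (all names here are hypothetical, not from the patch):

# Each layer forwards what the layer below yields; nothing is buffered,
# so the top-level consumer sees results as soon as they are produced.
def _finish(values):
    for v in values:
        yield f"done: {v}"

def _step_once(batch):
    yield from _finish(batch)        # analogous to _step -> _handle_finish_req

def forward_loop(batches):
    for batch in batches:
        yield from _step_once(batch) # analogous to loop_for_fwd -> _step

for item in forward_loop([[1, 2], [3]]):
    print(item)                      # done: 1, done: 2, done: 3

Note that the diff additionally declares loop_for_fwd as async def and drives _step with async for, which presumes _step is itself an async iterable; a purely synchronous chain like the sketch above would instead be consumed with a plain for loop.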
@@ -87,6 +106,26 @@ class DynamicBatchManager:
             if self.running_batch is None:
                 time.sleep(0.1)  # 10ms
 
+    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
+        if tokenizer is not None:
+            self.tokenizer = tokenizer
+        else:
+            if "llama" in tokenizer_name.lower() and use_fast == True:
+                print(
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai.")
+
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+            except TypeError as e:
+                use_fast = False
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+
     def _step(self):
         """
         Logic for handling requests
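For reference, a small usage sketch of the tokenizer setup above; `manager` is assumed to be an existing DynamicBatchManager, and the repo id shown is the fast-tokenizer fallback already named in the hunk (any name or path accepted by AutoTokenizer.from_pretrained works the same way):

# Either let the manager resolve the tokenizer by name ...
manager._set_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer")

# ... or build one yourself and hand it over, which skips the lookup entirely.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=True)
manager._set_tokenizer(tokenizer=tok)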
@@ -97,14 +136,14 @@ class DynamicBatchManager:
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                self._prefill_batch(self.running_batch)
+                yield from self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
@@ -112,17 +151,18 @@ class DynamicBatchManager:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                self._prefill_batch(new_mini_batch)
+                yield from self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
+
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                self._decode_batch(self.running_batch)
+                yield from self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
 
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -158,7 +198,8 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
+
         # delete finished reqs
 
     def _decode_batch(self, batch: Batch):
@@ -169,7 +210,7 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -201,11 +242,13 @@ class DynamicBatchManager:
 
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
+            yield from self._output_process(finished_reqs)
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -218,26 +261,47 @@ class DynamicBatchManager:
             req.output_metadata_list.append(new_gen_metadata)
         return
 
+    async def _output_process(self, finished_reqs: List[Req]):
+        """
+        Process the output of a batch.
+        """
+        for req in finished_reqs:
+            output = self.tokenizer.decode(req.output_ids)
+            yield output, req.request_id, req.output_metadata_list
+
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
+
+    async def generate(self, request_id, prompt_id, sampling_params):
+        """
+        Generate the output of a request.
+        """
+        self.add_input(request_id, prompt_id, sampling_params)
 
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    # try:
-    batch_manager = DynamicBatchManager(
-        tp_engine=tp_engine,
-        max_total_token_num=args.max_total_token_num,
-        batch_max_tokens=args.batch_max_tokens,
-        eos_id=args.eos_id,
-        log_stats=not args.disable_log_stats,
-        log_stats_interval=args.log_stats_interval,
-        waiting_req_list=waiting_req_list,
-    )
-
-    # except Exception:
-    #     batch_manager.clean_up()
-    #     raise
-
-    batch_manager.loop_for_fwd()
-    return
+    try:
+        batch_manager = DynamicBatchManager(
+            tp_engine=tp_engine,
+            max_total_token_num=args.max_total_token_num,
+            batch_max_tokens=args.batch_max_tokens,
+            eos_id=args.eos_id,
+            log_stats=not args.disable_log_stats,
+            log_stats_interval=args.log_stats_interval,
+            waiting_req_list=waiting_req_list,
+        )
+
+    except Exception:
+        batch_manager.clean_up()
+        raise
+
+    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
+    prod_task = asyncio.create_task(batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world"))
+
+    asyncio.run(prod_task)
+
+    for item in batch_manager.loop_for_fwd():
+        print(item)
+
+    return batch_manager
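Finally, a hedged sketch of how the streaming loop could be driven end to end, assuming a batch_manager built as in start_dynamic_batching above and a generator chain under loop_for_fwd that is async-iterable; serve() is a hypothetical driver and SamplingParams is the same class imported by this module:

import asyncio

# loop_for_fwd() is declared `async def` and yields, i.e. it is an async
# generator, so it has to be consumed with `async for` inside a coroutine.
async def serve(batch_manager):
    batch_manager.add_input(0, SamplingParams(), "hello world")
    async for output, request_id, metadata in batch_manager.loop_for_fwd():
        print(request_id, output)

# asyncio.run(serve(batch_manager))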