mirror of https://github.com/hpcaitech/ColossalAI
[inference] Async dynamic batching (#4894)
* finish input and output logic
* add generate
* test forward

pull/4905/head
parent: e0757c31fb
commit: fced140250
@@ -102,17 +102,21 @@ class Batch:
                 has_new_finish = True
         return has_new_finish
 
-    def filter_finished(self):
+    def filter_finished(self) -> List[Req]:
         """
         Filter finished requests from the batch; the finished ones are removed from 'reqs'.
         """
         # TODO: the logic of return should be defined here.
         unfinished_req = []
+        finished_req = []
         for req in self.reqs:
             if not req.has_generate_finished:
                 unfinished_req.append(req)
+            else:
+                finished_req.append(req)
         self.reqs = unfinished_req
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
+        return finished_req
 
     def is_clear(self):
         return len(self.reqs) == 0
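With this change, filter_finished both prunes self.reqs and hands the finished requests back to the caller, so the caller can post-process them (as _handle_finish_req does further down). A minimal standalone sketch of the same split-and-return pattern, using a simplified stand-in class (illustrative only, not the repository's io_struct.Req):

from typing import List


class _Req:
    # simplified stand-in for io_struct.Req, for illustration only
    def __init__(self, request_id, finished=False):
        self.request_id = request_id
        self.has_generate_finished = finished


def split_finished(reqs: List[_Req]):
    # same split-and-return pattern as Batch.filter_finished
    unfinished, finished = [], []
    for req in reqs:
        (finished if req.has_generate_finished else unfinished).append(req)
    return unfinished, finished


unfinished, finished = split_finished([_Req(0), _Req(1, finished=True)])
print([r.request_id for r in finished])  # [1]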
@@ -1,5 +1,6 @@
 import time
 from typing import List
+import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -8,6 +9,8 @@ from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
+from transformers import AutoTokenizer
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 class DynamicBatchManager:
     def __init__(
@@ -54,6 +57,20 @@ class DynamicBatchManager:
         self.req_queue.append(req)
         return
 
+    def add_input(self, request_id, sampling_params, input_ids):
+        """
+        Encode and add a new input to the request queue. Only single-sequence input is supported for now.
+        """
+        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_len = len(prompt_ids)
+        if prompt_len > self.engine.max_input_len:
+            raise ValueError(
+                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
+            )
+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+        self.add_req(prompt_ids, sampling_params, request_id)
+        return
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
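add_input tokenizes the raw prompt, rejects prompts longer than engine.max_input_len, converts the stop sentences of sampling_params into token ids, and enqueues the request. A hedged sketch of just the tokenize-and-validate step, assuming the hf-internal-testing/llama-tokenizer checkpoint named above and a made-up 256-token limit:

from transformers import AutoTokenizer

MAX_INPUT_LEN = 256  # assumed limit for illustration; in the manager it comes from the TPInferEngine

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
prompt_ids = tokenizer.encode("hello world")
if len(prompt_ids) > MAX_INPUT_LEN:
    raise ValueError(f"the input prompt token len {len(prompt_ids)} is too long > {MAX_INPUT_LEN}")
print(prompt_ids)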
@@ -66,13 +83,15 @@ class DynamicBatchManager:
                     req.aborted = True
         return
 
-    def loop_for_fwd(self):
+    async def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        while self.running_batch is not None or self.req_queue.waiting_req_list:
-            self._step()
+        # self.running_batch is not None or self.req_queue.waiting_req_list
+        while True:
+            async for item in self._step():
+                yield item
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
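Declaring loop_for_fwd with async def and a yield inside makes it an asynchronous generator, so its results have to be consumed with "async for" from inside a coroutine. A self-contained sketch of that consumption pattern with a toy _step (not the manager's real one):

import asyncio


async def _step(step_id):
    # toy stand-in for DynamicBatchManager._step: yields the outputs finished in this step
    await asyncio.sleep(0)  # hand control back to the event loop
    yield f"output-from-step-{step_id}"


async def loop_for_fwd(num_steps=3):
    # simplified version of the manager's loop: forward every item that _step yields
    for step_id in range(num_steps):
        async for item in _step(step_id):
            yield item


async def main():
    async for item in loop_for_fwd():
        print(item)


asyncio.run(main())

Note that, as committed, _step below is still declared with a plain def; the "async for item in self._step()" line would require _step to become asynchronous as well, otherwise a plain for (or yield from) is needed.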
@@ -87,6 +106,26 @@ class DynamicBatchManager:
             if self.running_batch is None:
                 time.sleep(0.1)  # 10ms
 
+    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
+        if tokenizer is not None:
+            self.tokenizer = tokenizer
+        else:
+            if "llama" in tokenizer_name.lower() and use_fast:
+                print(
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai."
+                )
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+            except TypeError:
+                use_fast = False
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+
     def _step(self):
         """
         Logic for handling requests
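_set_tokenizer accepts either a ready-made tokenizer object or a checkpoint name; for LLaMA-style names with use_fast=True it substitutes the faster hf-internal-testing/llama-tokenizer, and it retries with a slow tokenizer if constructing the fast one raises TypeError. A hedged sketch of that fallback in isolation (the helper name load_tokenizer is made up for illustration):

from transformers import AutoTokenizer


def load_tokenizer(name: str, use_fast: bool = True, trust_remote_code: bool = False):
    # try the fast tokenizer first and fall back to the slow one on TypeError,
    # mirroring the fallback in DynamicBatchManager._set_tokenizer
    try:
        return AutoTokenizer.from_pretrained(name, use_fast=use_fast, trust_remote_code=trust_remote_code)
    except TypeError:
        return AutoTokenizer.from_pretrained(name, use_fast=False, trust_remote_code=trust_remote_code)


tokenizer = load_tokenizer("hf-internal-testing/llama-tokenizer")
print(type(tokenizer).__name__)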
@@ -97,14 +136,14 @@ class DynamicBatchManager:
         if new_batch is not None:
             self.stats_tool.count_prompt_tokens(new_batch)
             self.running_batch = new_batch
-            self._prefill_batch(self.running_batch)
+            yield from self._prefill_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
@@ -112,17 +151,18 @@ class DynamicBatchManager:
         new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
         if new_mini_batch is not None:
             self.stats_tool.count_prompt_tokens(new_mini_batch)
-            self._prefill_batch(new_mini_batch)
+            yield from self._prefill_batch(new_mini_batch)
             if not new_mini_batch.is_clear():
                 self._merge_batch(self.running_batch, new_mini_batch)
                 self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
 
         else:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
 
+        return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
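The yield from calls turn _step into a generator that simply re-emits whatever _prefill_batch and _decode_batch yield, which in turn re-emit what _handle_finish_req yields. A small sketch of this delegation chain with toy functions (not the real helpers):

def _handle_finish_req(finished):
    # innermost generator: yields one output per finished request
    for req_id in finished:
        yield f"request {req_id} finished"


def _decode_batch(finished):
    # middle layer: forwards whatever _handle_finish_req yields
    yield from _handle_finish_req(finished)


def _step(finished):
    # outermost layer: the caller sees the items from the innermost generator
    yield from _decode_batch(finished)


for item in _step([3, 7]):
    print(item)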
@@ -158,7 +198,8 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
+
         # delete finished reqs
 
     def _decode_batch(self, batch: Batch):
@@ -169,7 +210,7 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -201,11 +242,13 @@ class DynamicBatchManager:
 
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
+            yield from self._output_process(finished_reqs)
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
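One detail worth noting: _output_process (added in the next hunk) is declared async def while _handle_finish_req stays a plain generator, and yield from cannot delegate to an asynchronous generator. A hedged sketch of a purely synchronous version of the pair, with the decode step stubbed out so it runs without a tokenizer:

from typing import List


def _output_process(finished_reqs: List[dict]):
    # synchronous counterpart of DynamicBatchManager._output_process:
    # turn each finished request into (output_text, request_id, metadata)
    for req in finished_reqs:
        output = " ".join(str(t) for t in req["output_ids"])  # stand-in for tokenizer.decode
        yield output, req["request_id"], req["metadata"]


def _handle_finish_req(finished_reqs):
    # mirror of the manager's method: forward every finished request's output
    if finished_reqs:
        yield from _output_process(finished_reqs)


for out, rid, meta in _handle_finish_req([{"output_ids": [5, 6], "request_id": 0, "metadata": []}]):
    print(rid, out)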
@@ -218,26 +261,47 @@ class DynamicBatchManager:
             req.output_metadata_list.append(new_gen_metadata)
         return
 
+    async def _output_process(self, finished_reqs: List[Req]):
+        """
+        Process the output of a batch.
+        """
+        for req in finished_reqs:
+            output = self.tokenizer.decode(req.output_ids)
+            yield output, req.request_id, req.output_metadata_list
+
+    def clean_up(self):
+        # this logic should be implemented in the future.
+        pass
+
+    async def generate(self, request_id, prompt_id, sampling_params):
+        """
+        Generate the output of a request.
+        """
+        self.add_input(request_id, prompt_id, sampling_params)
+
+
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    # try:
-    batch_manager = DynamicBatchManager(
-        tp_engine=tp_engine,
-        max_total_token_num=args.max_total_token_num,
-        batch_max_tokens=args.batch_max_tokens,
-        eos_id=args.eos_id,
-        log_stats=not args.disable_log_stats,
-        log_stats_interval=args.log_stats_interval,
-        waiting_req_list=waiting_req_list,
-    )
+    try:
+        batch_manager = DynamicBatchManager(
+            tp_engine=tp_engine,
+            max_total_token_num=args.max_total_token_num,
+            batch_max_tokens=args.batch_max_tokens,
+            eos_id=args.eos_id,
+            log_stats=not args.disable_log_stats,
+            log_stats_interval=args.log_stats_interval,
+            waiting_req_list=waiting_req_list,
+        )
 
-    # except Exception:
-    #     batch_manager.clean_up()
-    #     raise
+    except Exception:
+        batch_manager.clean_up()
+        raise
 
+    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
+    prod_task = asyncio.create_task(batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world"))
 
-    batch_manager.loop_for_fwd()
-    return
+    asyncio.run(prod_task)
+
+    for item in batch_manager.loop_for_fwd():
+        print(item)
+
+    return batch_manager
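As written, asyncio.create_task expects a coroutine but add_input is an ordinary method, asyncio.run expects a coroutine rather than a task, and loop_for_fwd is an asynchronous generator that cannot be iterated with a plain for loop. A hedged sketch of how the same flow could be driven from a single coroutine instead (the request id 4 and the "hello world" prompt are kept from the hunk above; the three-way unpacking follows what _output_process yields):

import asyncio


async def drive(batch_manager, sampling_params):
    # enqueue one request, then stream outputs from the asynchronous forward loop;
    # sketch only: it assumes loop_for_fwd eventually stops yielding once all requests finish
    batch_manager.add_input(4, sampling_params=sampling_params, input_ids="hello world")
    async for output, request_id, metadata in batch_manager.loop_for_fwd():
        print(request_id, output)


# asyncio.run(drive(batch_manager, SamplingParams()))  # once the manager has been constructed

Note also that generate passes (request_id, prompt_id, sampling_params) positionally while add_input is declared as (request_id, sampling_params, input_ids); the keyword form used in the sketch sidesteps that ordering mismatch.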
@@ -0,0 +1,33 @@
+import asyncio
+
+shared_list = []
+
+async def producer():
+    for i in range(5):
+        await asyncio.sleep(1)  # simulate fetching data asynchronously
+        shared_list.append(i)
+        print(f"Produced {i}")
+
+async def consumer():
+    last_index = 0
+    while True:
+        await asyncio.sleep(0.5)  # a small delay so the polling loop is not too tight
+        if last_index < len(shared_list):
+            item = shared_list[last_index]
+            print(f"Consumed {item}")
+            yield item
+            last_index += 1
+
+async def main():
+    # create the producer task
+    prod_task = asyncio.create_task(producer())
+
+    # wait for the producer task to finish
+    await prod_task
+
+    async for data in consumer():
+        print(data)
+        # for the purpose of this example, only wait for a while and then stop the consumer
+        await asyncio.sleep(5)
+
+asyncio.run(main())
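The new demo file polls shared_list on a timer; the same producer/consumer hand-off can be written without polling by using asyncio.Queue, shown below as an alternative sketch (not part of the commit):

import asyncio


async def producer(queue: asyncio.Queue):
    for i in range(5):
        await asyncio.sleep(0.1)   # simulate asynchronously produced data
        await queue.put(i)
        print(f"Produced {i}")
    await queue.put(None)          # sentinel: tell the consumer to stop


async def consumer(queue: asyncio.Queue):
    while True:
        item = await queue.get()   # wakes up as soon as an item is available
        if item is None:
            break
        print(f"Consumed {item}")


async def main():
    queue = asyncio.Queue()
    await asyncio.gather(producer(queue), consumer(queue))


asyncio.run(main())

A queue also gives natural backpressure: if it is created with a maxsize, await queue.put(...) suspends the producer once the queue fills up, which a bare shared list cannot do.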
@@ -42,7 +42,7 @@ def run():
     waiting_list.append(req2)
     waiting_list.append(req3)
     waiting_list.append(req4)
 
     llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
     model = LlamaForCausalLM(llama_config)
     model = model.half()
@@ -50,7 +50,13 @@ def run():
     shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
 
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+    manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+    manager._set_tokenizer(tokenizer_name=model.__class__.__name__)
+    result_generator = manager.loop_for_fwd()
+    for result in result_generator:
+        print(result)
 
 
 def check_dynamic_forward(rank, world_size, port):
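The updated test keeps the returned manager, points its tokenizer at the model's class name ("LlamaForCausalLM" contains "llama", so _set_tokenizer redirects to the fast LLaMA tokenizer), and iterates loop_for_fwd for results; since loop_for_fwd is now asynchronous, that final loop would need "async for" inside a coroutine. For reference, a hedged, CPU-only sketch of the tiny model the test builds (fp16 and the tensor-parallel engine are omitted so it runs standalone):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# same tiny 2-layer config as the test; half precision and sharding are left out for a plain CPU run
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)

input_ids = torch.randint(0, llama_config.vocab_size, (1, 8))
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # torch.Size([1, 8, 1200])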