[inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Author: Jianghai (committed by GitHub), branch pull/4905/head
Commit fbf3c09e67, parent fced140250
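Before the diff, a quick orientation: this commit turns the dynamic-batching forward loop into an async generator and makes request submission awaitable. The toy program below is a minimal, self-contained sketch of that pattern using plain asyncio; the names ToyManager, add_req and loop_for_fwd only mirror the shape of the Colossal-AI API changed below and are not the real implementation.

import asyncio
from typing import List


class ToyManager:
    """Mimics the shape of DynamicBatchManager: an awaitable add_req plus an async-generator forward loop."""

    def __init__(self):
        self.waiting: List[str] = []

    async def add_req(self, prompt: str):
        # requests are only queued here; the forward loop picks them up later
        self.waiting.append(prompt)

    async def loop_for_fwd(self):
        while True:
            if self.waiting:
                req = self.waiting.pop(0)
                # stand-in for prefill/decode; yield the "generated" text back to the caller
                yield f"{req} -> <generated tokens>"
            else:
                # nothing to do: yield control to the event loop instead of blocking it
                await asyncio.sleep(0.1)


async def main():
    manager = ToyManager()

    async def consume():
        # drain the forward loop in the background, like process_data() in the diff
        async for out in manager.loop_for_fwd():
            print(out)

    task = asyncio.create_task(consume())
    await manager.add_req("hello world")
    await asyncio.sleep(0.5)  # give the loop time to process the request
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


if __name__ == "__main__":
    asyncio.run(main())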

@@ -4,7 +4,7 @@ from .sampling_params import SamplingParams
 class Req:
-    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
+    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
         self.request_id = request_id
         self.prompt_ids = prompt_ids
         self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ class Req:
         self.output_metadata_list = []
         self.has_generate_finished = False
         self.aborted = False
+        self.prompts = prompts

     def to_rpc_obj(self):
         return {
@@ -36,7 +37,11 @@ class Req:
         if self.sample_params.stop_sequences is not None:
             for stop_token_ids in self.sample_params.stop_sequences:
                 stop_len = len(stop_token_ids)
-                if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
+                if (
+                    stop_len > 0
+                    and len(self.output_ids) >= stop_len
+                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
+                ):
                     return True
         return False
@@ -102,7 +107,7 @@ class Batch:
                     has_new_finish = True
         return has_new_finish

-    def filter_finished(self)->List[Req]:
+    def filter_finished(self) -> List[Req]:
        """
        Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
        """

@ -1,6 +1,7 @@
import time
from typing import List
import asyncio import asyncio
from typing import List
from transformers import AutoTokenizer
from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req from .dynamic_batching.io_struct import Batch, Req
@@ -9,9 +10,9 @@ from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
-from transformers import AutoTokenizer

 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

+
 class DynamicBatchManager:
     def __init__(
         self,
@@ -19,6 +20,7 @@ class DynamicBatchManager:
         max_total_token_num,
         batch_max_tokens,
         eos_id,
+        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -30,6 +32,7 @@ class DynamicBatchManager:
         batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
         running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
         eos_id : The end token of a seq
+        model: the model weight dir path, the app will load config, weights and tokenizer from this dir
         log_stats : whether to log stats
         log_stats_interval : log stats interval
         running_batch : running batch
@@ -45,30 +48,30 @@ class DynamicBatchManager:
         self.eos_id = eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
+        self.model = model

         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
+        self._set_tokenizer(tokenizer_name=self.model)

-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
+    async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
         """
         Add new request to req queue, during initialization all requests are held in waiting list.
         """
-        req = Req(request_id, prompt_ids, sampling_params)
+        req = Req(request_id, prompt_ids, sampling_params, prompts)
         self.req_queue.append(req)
         return

-    def add_input(self, request_id, sampling_params, input_ids):
+    async def add_input(self, request_id, sampling_params, prompts):
         """
         Encode and Add new input to req queue. support one sequence input for now.
         """
-        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_ids = self.tokenizer.encode(prompts)
         prompt_len = len(prompt_ids)
         if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
+            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
+        self.add_req(request_id, prompt_ids, sampling_params, prompts)
         return

     def abort(self, request_id):
@@ -88,10 +91,15 @@ class DynamicBatchManager:
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        #self.running_batch is not None or self.req_queue.waiting_req_list
+        # self.running_batch is not None or self.req_queue.waiting_req_list
         while True:
-            async for item in self._step():
-                yield item
+            if self.running_batch is not None or self.req_queue.waiting_req_list:
+                async for result in self._step():
+                    yield result
+            else:
+                # need to wait for new requests
+                await asyncio.sleep(0.1)
+                continue
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -103,30 +111,33 @@ class DynamicBatchManager:
                     )
                 self.stats_tool.print_stats()

-            if self.running_batch is None:
-                time.sleep(0.1) # 10ms
-
-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
+    def _set_tokenizer(
+        self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
+    ):
         if tokenizer is not None:
             self.tokenizer = tokenizer
         else:
             if "llama" in tokenizer_name.lower() and use_fast == True:
                 print(
                     "For some LLaMA-based models, initializing the fast tokenizer may "
                     "take a long time. To eliminate the initialization time, consider "
                     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
+                    "tokenizer. This is done automatically in Colossalai."
+                )
                 tokenizer_name = _FAST_LLAMA_TOKENIZER
             try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-            except TypeError as e:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
+            except TypeError:
                 use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )

-    def _step(self):
+    async def _step(self):
         """
         Logic for handling requests
         """
@@ -136,14 +147,15 @@ class DynamicBatchManager:
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                async for item in self._prefill_batch(self.running_batch):
+                    yield item
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return

         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
@@ -151,7 +163,8 @@ class DynamicBatchManager:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                async for item in self._prefill_batch(new_mini_batch):
+                    yield item
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
@@ -159,7 +172,8 @@ class DynamicBatchManager:
         else:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            async for item in self._decode_batch(self.running_batch):
+                yield item
             self._filter_runing_batch()
             self.has_wait_tokens += 1
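A note on the recurring `yield from` → `async for` rewrites in the hunks above: `yield from` is not allowed inside an `async def` generator, so delegating to another async generator has to be spelled out as a loop that re-yields each item. A tiny self-contained illustration (generic names, unrelated to the manager code):

import asyncio


async def inner():
    for i in range(3):
        yield i


async def outer():
    # "yield from inner()" would be a SyntaxError inside an async generator,
    # so each item has to be forwarded explicitly
    async for item in inner():
        yield item


async def main():
    async for item in outer():
        print(item)  # prints 0, 1, 2


asyncio.run(main())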
@@ -187,7 +201,7 @@ class DynamicBatchManager:
         )
         self.engine.cache[batch_id] = batch_data

-    def _prefill_batch(self, batch):
+    async def _prefill_batch(self, batch):
         """
         For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
         """
@@ -198,11 +212,11 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item
         # delete finished reqs

-    def _decode_batch(self, batch: Batch):
+    async def _decode_batch(self, batch: Batch):
         """
         Decoding process
         """
@@ -210,7 +224,8 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item

     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -240,15 +255,15 @@ class DynamicBatchManager:
         batch.free_self()
         del batch

-    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+    async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs=batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
+            async for item in self._output_process(finished_reqs):
+                yield item

     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -267,17 +282,23 @@ class DynamicBatchManager:
         """
         for req in finished_reqs:
             output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
+            yield req.prompts + output

     def clean_up(self):
         # this logic should be implemented in the future.
         pass

-    async def generate(self,request_id,prompt_id,sampling_params):
+    async def generate(self, request_id, prompt_id, sampling_params):
         """
         Generate the output of a request.
         """
-        self.add_input(request_id,prompt_id,sampling_params)
+        await self.add_input(request_id, prompt_id, sampling_params)
+
+
+async def process_data(dbm):
+    async for data in dbm.loop_for_fwd():
+        print(data)


 def start_dynamic_batching(args, tp_engine, waiting_req_list):
@@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
             max_total_token_num=args.max_total_token_num,
             batch_max_tokens=args.batch_max_tokens,
             eos_id=args.eos_id,
+            model=args.model,
             log_stats=not args.disable_log_stats,
             log_stats_interval=args.log_stats_interval,
             waiting_req_list=waiting_req_list,
         )
     except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
+        raise RuntimeError("Failed to start dynamic batching")

     return batch_manager
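With the synchronous driving loop removed from start_dynamic_batching in the hunk above, the caller now owns the event loop. The sketch below shows how such driving code could look; it uses only names that appear in this diff (start_dynamic_batching, process_data, add_input, SamplingParams) and assumes `args` and `tp_engine` are prepared elsewhere as in the test further down, so it is an illustration rather than part of the commit.

import asyncio

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import process_data, start_dynamic_batching

# assumed to be prepared elsewhere: `args` (with model/eos_id/... fields) and a TPInferEngine `tp_engine`
manager = start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=[])


async def serve():
    # run the forward loop in the background; it prints finished outputs as they arrive
    loop_task = asyncio.create_task(process_data(manager))
    # submit a request: the prompt text is tokenized inside add_input
    await manager.add_input(request_id=0, sampling_params=SamplingParams(), prompts="hello world")
    await asyncio.sleep(5)  # let the loop generate for a while


asyncio.run(serve())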

@@ -1,33 +0,0 @@
-import asyncio
-
-shared_list = []
-
-
-async def producer():
-    for i in range(5):
-        await asyncio.sleep(1)  # simulate an asynchronous data-fetching operation
-        shared_list.append(i)
-        print(f"Produced {i}")
-
-
-async def consumer():
-    last_index = 0
-    while True:
-        await asyncio.sleep(0.5)  # add a small delay so the loop is not too tight
-        if last_index < len(shared_list):
-            item = shared_list[last_index]
-            print(f"Consumed {item}")
-            yield item
-            last_index += 1
-
-
-async def main():
-    # create the producer and consumer tasks
-    prod_task = asyncio.create_task(producer())
-
-    # wait for the producer task to finish
-    await prod_task
-
-    async for data in consumer():
-        print(data)
-
-    # for the purposes of this example, only wait for a while and then stop the consumer
-    await asyncio.sleep(5)
-
-
-asyncio.run(main())

@@ -1,3 +1,6 @@
+import asyncio
+from dataclasses import dataclass
+
 import pytest
 import torch
 from packaging import version
@@ -5,10 +8,9 @@ from transformers import LlamaForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig

 import colossalai
-from dataclasses import dataclass
 from colossalai.inference.dynamic_batching.io_struct import Req
 from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
-from colossalai.inference.manager import start_dynamic_batching
+from colossalai.inference.manager import process_data, start_dynamic_batching
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@@ -19,17 +21,26 @@ MAX_INPUT_LEN = 5
 MAX_OUTPUT_LEN = 16
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


 @dataclass
 class args:
     max_total_token_num: int
     batch_max_tokens: int
     eos_id: int
+    model: str
     disable_log_stats: bool
     log_stats_interval: int


 def run():
-    arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
+    arg = args(
+        max_total_token_num=42,
+        batch_max_tokens=42,
+        eos_id=0,
+        model="llama",
+        disable_log_stats=False,
+        log_stats_interval=10,
+    )
     sampling_params = SamplingParams()

     req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
@@ -51,12 +62,14 @@ def run():
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

-    manager._set_tokenizer(tokenizer_name = model.__class__.__name__)
-    result_generator = manager.loop_for_fwd()
-    for result in result_generator:
-        print(result)
+    asyncio.run(test(manager))
+
+
+async def test(manager):
+    asyncio.create_task(process_data(manager))
+    await asyncio.sleep(5)
+    await manager.add_req(4, [0, 0, 10, 10, 10], SamplingParams())
+    await asyncio.sleep(5)


 def check_dynamic_forward(rank, world_size, port):
