From fbf3c09e673794ed18c91d4bab1a7dfea052e95a Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 13 Oct 2023 11:01:18 +0800 Subject: [PATCH] [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> --- .../inference/dynamic_batching/io_struct.py | 15 +- colossalai/inference/manager.py | 139 ++++++++++-------- colossalai/inference/test_async.py | 33 ----- .../test_dynamic_batching/test_forward.py | 29 +++- 4 files changed, 107 insertions(+), 109 deletions(-) delete mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 44ad2964a..2028e320b 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -4,7 +4,7 @@ from .sampling_params import SamplingParams class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -14,6 +14,7 @@ class Req: self.output_metadata_list = [] self.has_generate_finished = False self.aborted = False + self.prompts = prompts def to_rpc_obj(self): return { @@ -36,7 +37,11 @@ class Req: if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): return True return False @@ -102,7 +107,7 @@ class Batch: has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self) -> List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. 
""" @@ -111,9 +116,9 @@ class Batch: finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) else: - finished_req.append(req) + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} return finished_req diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 453570c7e..61276660d 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,7 @@ -import time -from typing import List import asyncio +from typing import List + +from transformers import AutoTokenizer from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -9,9 +10,9 @@ from .dynamic_batching.sampling_params import SamplingParams from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from transformers import AutoTokenizer _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + class DynamicBatchManager: def __init__( self, @@ -19,6 +20,7 @@ class DynamicBatchManager: max_total_token_num, batch_max_tokens, eos_id, + model, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -30,6 +32,7 @@ class DynamicBatchManager: batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir log_stats : whether to log stats log_stats_interval : log stats interval running_batch : running batch @@ -45,32 +48,32 @@ class DynamicBatchManager: self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - + self.model = model + self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 + self._set_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): """ Add new request to req queue, during initialization all requests are held in waiting list. """ - req = Req(request_id, prompt_ids, sampling_params) + req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, input_ids): + async def add_input(self, request_id, sampling_params, prompts): """ Encode and Add new input to req queue. support one sequence input for now. """ - prompt_ids = self.tokenizer.encode(input_ids) + prompt_ids = self.tokenizer.encode(prompts) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id) + self.add_req(request_id, prompt_ids, sampling_params, prompts) return - + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -88,10 +91,15 @@ class DynamicBatchManager: The main loop for a dynamic batching process. 
""" counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list + # self.running_batch is not None or self.req_queue.waiting_req_list while True: - async for item in self._step(): - yield item + if self.running_batch is not None or self.req_queue.waiting_req_list: + async for result in self._step(): + yield result + else: + # need to wait for new requests + await asyncio.sleep(0.1) + continue counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -103,30 +111,33 @@ class DynamicBatchManager: ) self.stats_tool.print_stats() - if self.running_batch is None: - time.sleep(0.1) # 10ms - - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + def _set_tokenizer( + self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True + ): if tokenizer is not None: - self.tokenizer = tokenizer + self.tokenizer = tokenizer else: if "llama" in tokenizer_name.lower() and use_fast == True: print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." 
+            )
+
+            tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+            )
+        except TypeError:
             use_fast = False
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+            )
 
-    def _step(self):
+    async def _step(self):
         """
         Logic for handling requests
         """
@@ -136,14 +147,15 @@ class DynamicBatchManager:
         if new_batch is not None:
             self.stats_tool.count_prompt_tokens(new_batch)
             self.running_batch = new_batch
-            yield from self._prefill_batch(self.running_batch)
+            async for item in self._prefill_batch(self.running_batch):
+                yield item
             self._filter_runing_batch()
             self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            async for item in self._decode_batch(self.running_batch): yield item
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
@@ -151,18 +163,20 @@ class DynamicBatchManager:
         new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
         if new_mini_batch is not None:
             self.stats_tool.count_prompt_tokens(new_mini_batch)
-            yield from self._prefill_batch(new_mini_batch)
+            async for item in self._prefill_batch(new_mini_batch):
+                yield item
             if not new_mini_batch.is_clear():
                 self._merge_batch(self.running_batch, new_mini_batch)
                 self.running_batch.merge(new_mini_batch)
             self.has_wait_tokens = 0
-
+
         else:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            async for item in self._decode_batch(self.running_batch):
+                yield item
             self._filter_runing_batch()
             self.has_wait_tokens += 1
-
+
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -187,7 +201,7 @@ class DynamicBatchManager:
         )
         self.engine.cache[batch_id] = batch_data
 
-    def _prefill_batch(self, batch):
+    async def _prefill_batch(self, batch):
         """
         For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
""" @@ -198,11 +212,11 @@ class DynamicBatchManager: req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) - + async for item in self._handle_finish_req(batch, has_new_finished_req): + yield item # delete finished reqs - def _decode_batch(self, batch: Batch): + async def _decode_batch(self, batch: Batch): """ Decoding process """ @@ -210,7 +224,8 @@ class DynamicBatchManager: req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) + async for item in self._handle_finish_req(batch, has_new_finished_req): + yield item def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -240,15 +255,15 @@ class DynamicBatchManager: batch.free_self() del batch - def _handle_finish_req(self, batch: Batch, has_new_finished_req): + async def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) - yield from self._output_process(finished_reqs) - + async for item in self._output_process(finished_reqs): + yield item def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -267,18 +282,24 @@ class DynamicBatchManager: """ for req in finished_reqs: output = self.tokenizer.decode(req.output_ids) - yield output, req.request_id, req.output_metadata_list + yield req.prompts + output def clean_up(self): # this logic should be implemented in the future. pass - async def generate(self,request_id,prompt_id,sampling_params): + async def generate(self, request_id, prompt_id, sampling_params): """ Generate the output of a request. 
""" - self.add_input(request_id,prompt_id,sampling_params) - + + await self.add_input(request_id, prompt_id, sampling_params) + + +async def process_data(dbm): + async for data in dbm.loop_for_fwd(): + print(data) + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: @@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list): max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, eos_id=args.eos_id, + model=args.model, log_stats=not args.disable_log_stats, log_stats_interval=args.log_stats_interval, waiting_req_list=waiting_req_list, ) except Exception: - batch_manager.clean_up() - raise - - batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) - prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) - - asyncio.run(prod_task) - - for item in batch_manager.loop_for_fwd(): - print(item) + raise RuntimeError("Failed to start dynamic batching") return batch_manager diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py deleted file mode 100644 index 08720f36d..000000000 --- a/colossalai/inference/test_async.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio - -shared_list = [] - -async def producer(): - for i in range(5): - await asyncio.sleep(1) # 模拟异步获取数据的操作 - shared_list.append(i) - print(f"Produced {i}") - -async def consumer(): - last_index = 0 - while True: - await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 - if last_index < len(shared_list): - item = shared_list[last_index] - print(f"Consumed {item}") - yield item - last_index += 1 - -async def main(): - # 创建生产者和消费者任务 - prod_task = asyncio.create_task(producer()) - - # 等待生产者任务完成 - await prod_task - - async for data in consumer(): - print(data) - # 为了示例的目的,我们只等待一段时间,然后停止消费者 - await asyncio.sleep(5) - -asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index ca6401259..1b42e3a10 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -1,3 +1,6 @@ +import asyncio +from dataclasses import dataclass + import pytest import torch from packaging import version @@ -5,10 +8,9 @@ from transformers import LlamaForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig import colossalai -from dataclasses import dataclass from colossalai.inference.dynamic_batching.io_struct import Req from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.manager import process_data, start_dynamic_batching from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn @@ -19,17 +21,26 @@ MAX_INPUT_LEN = 5 MAX_OUTPUT_LEN = 16 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + @dataclass class args: max_total_token_num: int batch_max_tokens: int eos_id: int + model: str disable_log_stats: bool log_stats_interval: int def run(): - arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) + arg = args( + max_total_token_num=42, + batch_max_tokens=42, + eos_id=0, + model="llama", + disable_log_stats=False, + log_stats_interval=10, + ) sampling_params = SamplingParams() 
     req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
@@ -42,7 +53,7 @@ def run():
     waiting_list.append(req2)
     waiting_list.append(req3)
     waiting_list.append(req4)
-
+
     llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
     model = LlamaForCausalLM(llama_config)
     model = model.half()
@@ -51,12 +62,14 @@ def run():
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
 
     manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
-    manager._set_tokenizer(tokenizer_name = model.__class__.__name__)
-    result_generator = manager.loop_for_fwd()
-    for result in result_generator:
-        print(result)
+    asyncio.run(run_async_test(manager))
 
+async def run_async_test(manager):
+    asyncio.create_task(process_data(manager))
+    await asyncio.sleep(5)
+    await manager.add_req(4, [0, 0, 10, 10, 10], SamplingParams())
+    await asyncio.sleep(5)
 
 def check_dynamic_forward(rank, world_size, port):
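
For reference, the new pieces are meant to be wired together roughly as in the sketch below. It mirrors test_forward.py rather than documenting a finished API: the tiny LlamaConfig, the model="llama" tokenizer shortcut (which pulls hf-internal-testing/llama-tokenizer from the Hugging Face Hub) and the fixed sleeps come straight from the test, while SimpleNamespace, the helper names drive()/main(), and the minimal ShardConfig flags are illustrative assumptions; an initialized ColossalAI launch context and a CUDA device are also assumed.

# Illustrative sketch only: mirrors tests/test_infer/test_dynamic_batching/test_forward.py.
# Assumes colossalai.launch(...) has already been run on a CUDA-capable rank and that the
# fast LLaMA tokenizer can be downloaded from the Hugging Face Hub.
import asyncio
from types import SimpleNamespace

from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import process_data, start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig


async def drive(manager):
    # Consume decoded outputs in the background while feeding one request in.
    consumer = asyncio.create_task(process_data(manager))  # prints prompt + generated text
    await manager.add_req(0, [0, 0, 10, 6, 8], SamplingParams())
    await asyncio.sleep(5)  # give loop_for_fwd() time to prefill and decode
    consumer.cancel()


def main():
    model = LlamaForCausalLM(
        LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    ).half()
    # Assumption: a minimal shard config; the test builds its own ShardConfig with inference flags.
    engine = TPInferEngine(model, ShardConfig(enable_tensor_parallelism=False), 8, 5, 16)
    # SimpleNamespace stands in for the small `args` dataclass defined in the test.
    cfg = SimpleNamespace(
        max_total_token_num=42,
        batch_max_tokens=42,
        eos_id=0,
        model="llama",
        disable_log_stats=False,
        log_stats_interval=10,
    )
    manager = start_dynamic_batching(cfg, tp_engine=engine, waiting_req_list=[])
    asyncio.run(drive(manager))


if __name__ == "__main__":
    main()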