From cf579ff46dc8d6e4c8f9b311134b693c2aeb0adc Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 30 Oct 2023 10:52:19 +0800 Subject: [PATCH] [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li Co-authored-by: CjhHa1 --- colossalai/inference/async_engine.py | 133 +++++++ colossalai/inference/async_manager.py | 151 ++++++++ .../inference/dynamic_batching/__init__.py | 0 .../dynamic_batching/get_tokenizer.py | 40 ++ .../inference/dynamic_batching/infer_batch.py | 346 ++++++++++++++++++ .../inference/dynamic_batching/io_struct.py | 166 +++++++++ .../dynamic_batching/ray_dist_init.py | 152 ++++++++ .../dynamic_batching/ray_init_config.py | 58 +++ .../inference/dynamic_batching/req_queue.py | 73 ++++ .../dynamic_batching/sampling_params.py | 83 +++++ .../inference/dynamic_batching/stats.py | 45 +++ colossalai/inference/manager.py | 296 +++++++++++++++ .../quant/smoothquant/models/base_model.py | 1 - .../quant/smoothquant/models/llama.py | 27 +- .../inference/tensor_parallel/engine.py | 86 ++++- .../tensor_parallel/kvcache_manager.py | 4 +- .../tensor_parallel/modeling/bloom.py | 36 +- .../tensor_parallel/modeling/chatglm2.py | 17 +- .../tensor_parallel/modeling/llama.py | 49 +-- colossalai/kernel/triton/__init__.py | 1 - .../kernel/triton/copy_kv_cache_dest.py | 2 - requirements/requirements-test.txt | 2 + requirements/requirements.txt | 2 + tests/kit/model_zoo/transformers/llama.py | 6 +- tests/test_infer/test_chatglm2_infer.py | 1 - .../test_dynamic_batching/config.yaml | 14 + .../test_async_engine.py | 61 +++ .../test_dynamic_batching_manager.py | 95 +++++ .../test_offline_dynamic_batching.py | 84 +++++ .../test_dynamic_batching/test_ray_dist.py | 66 ++++ 30 files changed, 2005 insertions(+), 92 deletions(-) create mode 100644 colossalai/inference/async_engine.py create mode 100644 colossalai/inference/async_manager.py create mode 100644 colossalai/inference/dynamic_batching/__init__.py create mode 100644 colossalai/inference/dynamic_batching/get_tokenizer.py create mode 100644 colossalai/inference/dynamic_batching/infer_batch.py create mode 100644 colossalai/inference/dynamic_batching/io_struct.py create mode 100644 colossalai/inference/dynamic_batching/ray_dist_init.py create mode 100644 colossalai/inference/dynamic_batching/ray_init_config.py create mode 100644 colossalai/inference/dynamic_batching/req_queue.py create mode 100644 colossalai/inference/dynamic_batching/sampling_params.py create mode 100644 colossalai/inference/dynamic_batching/stats.py create mode 100644 colossalai/inference/manager.py create mode 100644 tests/test_infer/test_dynamic_batching/config.yaml create mode 100644 tests/test_infer/test_dynamic_batching/test_async_engine.py create mode 100644 tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py create mode 100644 tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py create mode 100644 tests/test_infer/test_dynamic_batching/test_ray_dist.py diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py new file mode 100644 index 000000000..d0890ba3e --- /dev/null +++ b/colossalai/inference/async_engine.py @@ -0,0 +1,133 @@ +import asyncio + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver + +from .dynamic_batching.io_struct import RequestOutput +from .dynamic_batching.sampling_params import SamplingParams + + +class RequestTracker: + """ + A class for trace down all the requests, abstraction for async + """ + + def __init__(self) -> None: + self._requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._requests + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def add_request(self, request_id: str): + """Add a request to be sent to the engine on the next background + loop iteration.""" + self._requests.put_nowait(request_id) + self.new_requests_event.set() # NOTE: we may find a better way to clear this event + + def add_stop(self): + """ + Add a StopIteration flag to stop async generator. + """ + self._finished_requests.put_nowait(StopIteration) + self.new_requests_event.clear() + + def process_request_output(self, request_output: RequestOutput) -> None: + """Process a request output from the engine.""" + self._finished_requests.put_nowait(request_output) + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._finished_requests.get() + # print("result of ", result) + if result is StopIteration: + raise StopAsyncIteration + return result + + +class Async_Engine: + + """ + Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager + Background loop: inference reqs in waiting list (Listen) + Request Tracker: manage incoming requests and restore finished ones + Generate: exposed func for add new input and return finished ones + """ + + def __init__( + self, + router_config, + engine_config, + start_engine_loop: bool = True, + ) -> None: + self.driver = Driver(router_config=router_config, engine_config=engine_config) + self.background_loop = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + def _step(self): + """ + Logic for handling requests + """ + request_outputs = self.driver.step() + if request_outputs is not None: + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output) + self._request_tracker.add_stop() + + def abort_request(self, request_id: str): + self.driver.abort(request_id) + + def _has_requests_in_progress(self): + return self.driver.is_running() + + async def run_loop_fwd(self): + has_requests_in_progress = self._has_requests_in_progress() + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + self._step() + await asyncio.sleep(0) + + @property + def is_running(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.is_running: + raise RuntimeError("Background loop is already running.") + + self._request_tracker.init_event() + + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) + self.background_loop = asyncio.shield(self.background_loop_unshielded) + + async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.driver.add_input(request_id, prompt, sampling_params) + self._request_tracker.add_request(request_id) + + async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + """ + The only exposed func, adding new request and return a async generator that yields the existing results. + """ + try: + if not self.is_running: + self.start_background_loop() + + await self.add_request(request_id, prompt, sampling_params) + + async for request_output in self._request_tracker: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the request. + self.abort_request(request_id) + raise e diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py new file mode 100644 index 000000000..60440a792 --- /dev/null +++ b/colossalai/inference/async_manager.py @@ -0,0 +1,151 @@ +from typing import List + +from .dynamic_batching.io_struct import Batch, Req, RequestOutput +from .manager import DynamicBatchManager +from .tensor_parallel import TPInferEngine + + +class Async_DynamicBatchManager(DynamicBatchManager): + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num: int, + batch_max_tokens: int, + model: str, + tokenizer=None, + eos_id=None, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + super().__init__( + tp_engine, + max_total_token_num, + batch_max_tokens, + model, + tokenizer, + eos_id, + log_stats, + log_stats_interval, + running_batch, + waiting_req_list, + ) + + def _step(self): + """ + Logic for handling requests + """ + has_new_finished = False + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + has_new_finished, outputs = self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + + else: + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + if has_new_finished: + return outputs + return None + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + return self._output_process(finished_reqs) + return None + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + outputs = [] + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) + return outputs + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = Async_DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + raise Exception + + return batch_manager diff --git a/colossalai/inference/dynamic_batching/__init__.py b/colossalai/inference/dynamic_batching/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py new file mode 100644 index 000000000..94aa3f243 --- /dev/null +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -0,0 +1,40 @@ +""" +Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. + +license: MIT, see LICENSE for more details. +""" + +from transformers import AutoTokenizer + +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + + +def get_tokenizer( + tokenizer=None, + tokenizer_name: str = "", + trust_remote_code: bool = False, + use_fast: bool = True, +): + if tokenizer is not None: + tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." + ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + return tokenizer diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py new file mode 100644 index 000000000..112784c15 --- /dev/null +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -0,0 +1,346 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import collections +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from colossalai.inference.tensor_parallel import MemoryManager + + +# make batch infer state an attr of InferBatch +class InferSamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + vocab_size: int = -1, + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + if self.top_k == -1: + self.top_k = vocab_size + return + + +@dataclass +class InferBatch: + batch_id: int + requests: List + requests_idx_mapping: Dict[int, int] + + input_ids: torch.Tensor + + all_input_ids: List[List[int]] + input_lengths: List[int] + + out_token_id_counts: List + sampling_param_list: List[InferSamplingParams] + + nopad_total_token_num: int + nopad_max_len_in_batch: int + nopad_b_loc: torch.Tensor + nopad_b_start_loc: torch.Tensor + nopad_b_seq_len: torch.Tensor + cache_manager: MemoryManager + max_total_len: int + + @classmethod + @torch.no_grad() + def init_batch( + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + cache_manager: MemoryManager, + vocab_size: int, + max_total_len: int, + ) -> "InferBatch": + input_lengths = [] + all_input_ids = [] + requests_idx_mapping = {} + + out_token_id_counts = [] + sampling_param_list = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") + # to avoid memory leak , we pre-allocate 12 more space for each batch. + nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") + for i, r in enumerate(requests): + # request id -> idx in list mapping + requests_idx_mapping[r["request_id"]] = i + + tokenized_input = r["input_id"] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + out_token_id_counts.append(collections.defaultdict(int)) + + # postprocessor + sampling_param = r["sampling_param"] + sampling_param["vocab_size"] = vocab_size + sampling_param_list.append(InferSamplingParams(**sampling_param)) + + nopad_total_token_num += input_length + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) + + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + + if len(requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + else: + input_ids = all_input_ids[0] + + # Create tensors on device + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + return cls( + batch_id=batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=cache_manager, + max_total_len=max_total_len, + ) + + @torch.no_grad() + def free_self(self) -> None: + """ + Free the memory of the InferBatch itself + """ + remove_index = [] + for idx in range(len(self)): + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + @torch.no_grad() + def filter(self, request_ids: List[int]) -> "InferBatch": + """ + Filter finished batch and return a new InferBatch with left ones. + """ + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + requests_idx_mapping = {} + indices = [] + requests = [] + all_input_ids = [] + input_lengths = [] + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + + left_idx = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + left_idx.append(idx) + + left_idx_set = set(left_idx) + remove_index = [] + for idx in range(len(self)): + if idx not in left_idx_set: + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + nopad_max_len_in_batch = 0 + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] + nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ + indices, + (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), + ] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + requests.append(self.requests[idx]) + all_input_ids.append(self.all_input_ids[idx]) + input_lengths.append(self.input_lengths[idx]) + + input_ids = self.input_ids[indices] + + return InferBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], + sampling_param_list=[self.sampling_param_list[_i] for _i in indices], + cache_manager=self.cache_manager, + max_total_len=self.max_total_len, + ) + + @classmethod + @torch.no_grad() + def merge(cls, batch1, batch2) -> "InferBatch": + """ + Return megerd new InferBatch + """ + requests = batch1.requests + batch2.requests + requests_idx_mapping = {} + new_batch_size = len(batch1) + len(batch2) + + input_ids = batch1.input_ids.new_empty(new_batch_size) + all_input_ids = [] + input_lengths = [] + out_token_id_counts = [] + sampling_param_list = [] + + cumulative_batch_size = 0 + nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) + max_total_len = max(batch1.max_total_len, batch2.max_total_len) + nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_start_loc_len_temp = 0 + batches = [batch1, batch2] + for i, batch in enumerate(batches): + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + input_ids[start_index:end_index] = batch.input_ids + nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_b_loc[ + start_index:end_index, + nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, + ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] + + all_input_ids.extend(batch.all_input_ids) + + input_lengths.extend(batch.input_lengths) + out_token_id_counts.extend(batch.out_token_id_counts) + sampling_param_list.extend(batch.sampling_param_list) + # Update + cumulative_batch_size += len(batch) + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( + nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") + ) + return InferBatch( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=batches[0].cache_manager, + max_total_len=max_total_len, + ) + + def __len__(self): + return len(self.requests) + + def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + top_ks: List[int] = [] + p_token_ids: List[int] = [] + p_token_counts: List[int] = [] + p_seq_len: List[int] = [ + 0, + ] + p_max_len_in_batch: int = 0 + for i, id_to_count in enumerate(self.out_token_id_counts): + sample_param = self.sampling_param_list[i] + presence_penalties.append(sample_param.presence_penalty) + frequency_penalties.append(sample_param.frequency_penalty) + temperatures.append(sample_param.temperature) + top_ps.append(sample_param.top_p) + top_ks.append(sample_param.top_k) + + for token_id, count in id_to_count.items(): + p_token_ids.append(token_id) + p_token_counts.append(count) + p_seq_len.append(len(id_to_count)) + p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) + + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") + frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") + temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") + top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") + top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") + p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") + p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") + p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") + p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32) + return ( + presence_penalties, + frequency_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + ) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py new file mode 100644 index 000000000..fc5ecfe57 --- /dev/null +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -0,0 +1,166 @@ +# Adapted from https://github.com/ModelTC/lightllm + +from typing import Dict, List, Tuple + +from .sampling_params import SamplingParams + + +class Req: + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): + self.request_id = request_id + self.prompt_ids = prompt_ids + self.input_len = len(prompt_ids) + self.max_output_len = sample_params.max_new_tokens + self.sample_params = sample_params + self.output_ids = [] + self.output_metadata_list = [] + self.has_generate_finished = False + self.aborted = False + self.prompts = prompts + + def to_rpc_obj(self): + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + } + + def stop_sequences_matched(self): + # should we add stpp sequences to the sample params? + if self.sample_params.stop_sequences is not None: + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): + return True + return False + + def __repr__(self): + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + + +class Batch: + def __init__(self, batch_id, reqs: List[Req]): + self.batch_id = batch_id + self.reqs = reqs + self.id_to_reqs = {req.request_id: req for req in reqs} + + def input_tokens(self): + batch_input_tokens = 0 + for req in self.reqs: + batch_input_tokens += req.input_len + return batch_input_tokens + + def calcu_max_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + req.max_output_len + return tokens + + def calcu_used_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + len(req.output_ids) + return tokens + + def mark_finished_req(self, eos_id, engine_max_output_len): + has_new_finish = False + for req in self.reqs: + if req.stop_sequences_matched(): + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= engine_max_output_len: + req.has_generate_finished = True + has_new_finish = True + if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= req.max_output_len or req.aborted: + req.has_generate_finished = True + has_new_finish = True + return has_new_finish + + def filter_finished(self) -> List[Req]: + """ + Filter finished requests from the batch, the finished ones will be removed from 'reqs'. + """ + # TODO: the logic of return should be defined here. + unfinished_req = [] + finished_req = [] + for req in self.reqs: + if not req.has_generate_finished: + unfinished_req.append(req) + else: + finished_req.append(req) + self.reqs = unfinished_req + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req + + def is_clear(self): + return len(self.reqs) == 0 + + def merge(self, mini_batch): + for _req in mini_batch.reqs: + self.reqs.append(_req) + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return + + def __repr__(self): + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + def __len__(self): + return len(self.reqs) + + +class BatchTokenIdOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, int, Dict, bool, bool] + ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + + +class BatchStrOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, str, Dict, bool, bool] + ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + + +class AbortReq: + def __init__(self, req_id): + self.req_id = req_id + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + outputs: The output sequences of the request. + """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"outputs={self.outputs}, " + ) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py new file mode 100644 index 000000000..70ef489d3 --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -0,0 +1,152 @@ +import logging +import os +from typing import List + +import ray +import ray.util.collective as collective +import torch +from transformers import AutoModelForCausalLM + +import colossalai +from colossalai.inference.async_manager import start_dynamic_batching +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.io_struct import RequestOutput +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +ray_serve_logger = logging.getLogger("ray.serve") + + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + + +@ray.remote(num_gpus=1) +class Worker: + def __init__( + self, + model_path: str, + tensor_parallel_size: int, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + router_config: RooterArgsClass, + ): + log_cuda_info("Worker.init") + self.tensor_parallel_size = tensor_parallel_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.router_config = router_config + + def setup(self, world_size, rank, port): + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) + + return True + + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: + # ray_serve_logger.info(f"text: {prompt}") + + # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + + # return final_outputs + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) + + def abort(self, request_id: str): + self.start_dynamic_batching.abort(request_id) + + def step(self) -> List[RequestOutput]: + return self.start_dynamic_batching._step() + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) + + def is_running(self): + return self.start_dynamic_batching.is_running() + + +class Driver: + def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): + log_cuda_info("Driver:init") + model_path = engine_config.model + tensor_parallel_size = engine_config.tensor_parallel_size + + self.num_workers = tensor_parallel_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, + self.num_workers, + engine_config.max_batch_size, + engine_config.max_input_len, + engine_config.max_output_len, + router_config, + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) + + def abort(self, request_id: str): + ray.get([w.abort.remote(request_id) for w in self.workers]) + + def step(self): + results = ray.get([w.step.remote() for w in self.workers]) + outputs = results[0] # get any one of the copies + return outputs + + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str): + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) + + def is_running(self): + results = ray.get([w.is_running.remote() for w in self.workers]) + return any(results) diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py new file mode 100644 index 000000000..471f07330 --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_init_config.py @@ -0,0 +1,58 @@ +import logging + +import yaml +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class EngineArgsClass(BaseModel): + """Config for Engine""" + + model: str + tensor_parallel_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + + +class RooterArgsClass(BaseModel): + """Config for Rooter""" + + max_total_token_num: int = 42 + batch_max_tokens: int = 42 + eos_id: int = 0 + disable_log_stats: bool = False + log_stats_interval: int = 10 + model: str + + +class RayInitConfig(BaseModel): + """All-together configs without app router config""" + + engine_config_data: EngineArgsClass + router_config_data: RooterArgsClass + + @classmethod + def from_yaml_path(cls, path: str): + try: + with open(path, "r") as yaml_file: + try: + config = yaml.safe_load(yaml_file) + # serve deployment config + engine_config = config.get("engine_config", {}) + router_config = config.get("router_config", {}) + + return cls( + engine_config_data=engine_config, + router_config_data=router_config, + ) + except yaml.YAMLError as e: + logger.error(f"An Error occurred when parsing yaml: {e}") + raise + except FileNotFoundError: + logger.error(f"The file '{path}' does not exist!") + raise + except OSError as e: + logger.error(f"An Error occurred: {e}") + raise diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py new file mode 100644 index 000000000..0de43bd1a --- /dev/null +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -0,0 +1,73 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import uuid +from typing import List + +import numpy as np + +from .io_struct import Batch, Req + + +class ReqQueue: + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: + self.max_total_tokens = max_total_tokens + assert batch_max_tokens is not None + self.batch_max_tokens = batch_max_tokens + self.running_max_req_size = running_max_req_size + self.waiting_req_list: List[Req] = waiting_req_list + + def append(self, req): + self.waiting_req_list.append(req) + return + + def _init_cache_list(self, current_batch: Batch): + if current_batch is not None: + self.cache_len_list = [ + (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) + for req in current_batch.reqs + ] + else: + self.cache_len_list = [] + + # @calculate_time(show=True, min_cost_ms=0.1) + def _can_add_new_req(self, req): + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.sort(key=lambda x: -x[1]) + + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + # assert left_out_len_array.min() >= 0 + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + # NOTE: change here < to <= + return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size + + def generate_new_batch(self, current_batch: Batch = None): + if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: + return None + self._init_cache_list(current_batch) + can_run_list = [] + new_batch_total_tokens = 0 + aborted_count = 0 + for req in self.waiting_req_list: + flag = self._can_add_new_req(req) + if req.aborted: + aborted_count += 1 + continue + if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + can_run_list.append(req) + new_batch_total_tokens += req.input_len + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + return new_batch + else: + return None + + def __len__(self): + return self.waiting_req_list.__len__() diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py new file mode 100644 index 000000000..a37a83390 --- /dev/null +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -0,0 +1,83 @@ +# Adapted from https://github.com/ModelTC/lightllm + +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-5 + + +class SamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, # -1 is for all + ignore_eos: bool = False, + max_new_tokens: int = 256, + stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.ignore_eos = ignore_eos + self.max_new_tokens = max_new_tokens + self.stop_sequences = stop_sequences + if self.do_sample == False: + self.temperature = 1.0 + self.top_p = 1.0 + self.top_k = 1 + if ( + self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS + ): # temperature is too slow, change to greedy search + self.temperature = 1.0 + self.top_k = 1 + return + + def verify(self): + if self.presence_penalty < 0.0: + raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") + if self.frequency_penalty < 0.0: + raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") + if self.temperature <= 0.0: + raise ValueError(f"temperature must > 0.0, got {self.temperature}") + if self.top_p <= 0.0 or self.top_p > 1.0: + raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") + if self.max_new_tokens < 1: + raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") + return + + def stop_sentences_to_token_ids(self, tokenizer): + if self.stop_sequences is None: + self.stop_sequences = [] + else: + if isinstance(self.stop_sequences, str): + self.stop_sequences = [self.stop_sequences] + new_stop_sequences = [] + for stop_str in self.stop_sequences: + stop_str_ids = tokenizer.encode(stop_str) + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + stop_str_ids = stop_str_ids[1:] + if len(stop_str_ids) > 0: + new_stop_sequences.append(stop_str_ids) + self.stop_sequences = new_stop_sequences + return + + def to_dict(self): + ret = {} + ret["do_sample"] = self.do_sample + ret["presence_penalty"] = self.presence_penalty + ret["frequency_penalty"] = self.frequency_penalty + ret["temperature"] = self.temperature + ret["top_p"] = self.top_p + ret["top_k"] = self.top_k + # if self.ignore_eos is not None: + # ret["ignore_eos"] = self.ignore_eos + return ret diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py new file mode 100644 index 000000000..524072861 --- /dev/null +++ b/colossalai/inference/dynamic_batching/stats.py @@ -0,0 +1,45 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import time + + +class Stats: + def __init__(self, log_status, log_stats_interval) -> None: + self.log_stats = log_status + self.log_stats_interval = log_stats_interval + self.last_log_time = time.time() + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + return + + def count_prompt_tokens(self, run_batch): + if self.log_stats: + tokens = run_batch.input_tokens() + self.prompt_tokens += tokens + self.all_tokens += tokens + return + + def count_output_tokens(self, run_batch): + if self.log_stats: + tokens = len(run_batch.reqs) + self.output_tokens += tokens + self.all_tokens += tokens + return + + def print_stats(self): + if not self.log_stats: + return + + now = time.time() + if now - self.last_log_time > self.log_stats_interval: + print( + f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" + ) + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + self.last_log_time = now + return diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py new file mode 100644 index 000000000..9672a5014 --- /dev/null +++ b/colossalai/inference/manager.py @@ -0,0 +1,296 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import time +from typing import List + +from .dynamic_batching.get_tokenizer import get_tokenizer +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stats import Stats +from .tensor_parallel import TPInferEngine + + +class DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + model, + tokenizer=None, + eos_id=None, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + self.engine = tp_engine + self.max_total_token_num = max_total_token_num + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) + # all the inputs should be put into req_queue: waiting req list + assert max_total_token_num >= self.engine.max_batch_size * ( + self.engine.max_input_len + self.engine.max_output_len + ), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)" + assert ( + batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len + ), "batch_max_tokens should be greater than (max_input_len+max_output_len)" + self.running_batch: Batch = running_batch + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + self.model = model + + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 + self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer + if self.eos_id == None: + self.eos_id = self.tokenizer.eos_token_id + + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): + """ + Add new request to req queue, during initialization all requests are held in waiting list. + """ + sampling_params.max_new_tokens = ( + self.engine.max_output_len + if sampling_params.max_new_tokens > self.engine.max_output_len + else sampling_params.max_new_tokens + ) + req = Req(request_id, prompt_ids, sampling_params, prompts) + self.req_queue.append(req) + return + + def add_input(self, request_id, prompts, sampling_params): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(request_id, prompt_ids, sampling_params, prompts) + return + + def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. + """ + counter_count = 0 + # self.running_batch is not None or self.req_queue.waiting_req_list + while self.running_batch is not None or self.req_queue.waiting_req_list: + yield from self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % self.mem_usage_interval == 0: + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) + self.stats_tool.print_stats() + + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _step(self): + """ + Logic for handling requests + """ + + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + yield from self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + return + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + yield from self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + return + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + yield from self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + return + + def _init_batch(self, batch: Batch, dtype="fp16"): + reqs = [r.to_rpc_obj() for r in batch.reqs] + batch_id = batch.batch_id + + import torch + + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.cache_manager, + self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, + ) + self.engine.cache[batch_id] = batch_data + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + yield from self._handle_finish_req(batch, has_new_finished_req) + + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + yield from self._handle_finish_req(batch, has_new_finished_req) + + def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id + req_id_list = [r.request_id for r in batch.reqs] + batch = self.engine.cache.pop(batch_id) + filter_batch = batch.filter(req_id_list) + del batch + self.engine.cache[batch_id] = filter_batch + + def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. + """ + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) + + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 + del batch2 + + def _remove_batch(self, batch): + """ + Remove finished batch. + """ + batch = self.engine.cache.pop(batch.batch_id) + batch.free_self() + del batch + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + yield from self._output_process(finished_reqs) + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + + def clean_up(self): + # this logic should be implemented in the future. + pass + + def generate(self, request_id, prompts, sampling_params): + """ + Generate the output of a request. + """ + self.add_input(request_id, prompts, sampling_params) + return self.loop_for_fwd() + + def is_running(self): + return self.running_batch is not None or self.req_queue.waiting_req_list + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + raise Exception + + return batch_manager diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 6a1d96ece..9554be9ea 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -87,7 +87,6 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): batch_infer_state.start_loc = seq_start_indexes.to("cuda") batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 - batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) batch_infer_state.cache_manager.free_all() diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 4c3d6dcc0..30063857a 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -149,12 +149,6 @@ class LLamaSmoothquantAttention(nn.Module): self.k_rotary_output_scale.item(), ) - # NOTE might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) @@ -229,7 +223,7 @@ class LLamaSmoothquantAttention(nn.Module): infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) @@ -592,17 +586,13 @@ def llama_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - infer_state = self.infer_state + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + seq_length_with_past = seq_length + past_key_values_length # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -623,9 +613,7 @@ def llama_model_forward( infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem @@ -713,6 +701,7 @@ def llama_model_forward( infer_state.is_context_stage = False infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 next_cache = next_decoder_cache if use_cache else None if not return_dict: diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 216b134f5..e410532d8 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -13,6 +13,8 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager +# from dynamic_batching.infer_batch import InferBatch + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 _supported_models = [ @@ -61,7 +63,6 @@ class TPInferEngine: self.max_input_len = max_input_len self.max_output_len = max_output_len self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) - # Constraints relatable with specs of devices and model # This may change into an optional arg in the future assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" @@ -96,6 +97,8 @@ class TPInferEngine: self.shard_config = shard_config self.model = None + self.cache = {} + # optimize the original model by sharding with ShardFormer self._optimize_model(model=model.to(device)) @@ -284,7 +287,6 @@ class TPInferEngine: attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -318,6 +320,7 @@ class TPInferEngine: batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state @torch.no_grad() @@ -381,6 +384,85 @@ class TPInferEngine: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) infer_state.seq_len += 1 + @torch.no_grad() + def forward(self, batch_id, is_prefill): + """ + Forward is used in Dynamic Batching Manager + """ + batch = self.cache.pop(batch_id) + if is_prefill: + input_ = torch.tensor(batch.all_input_ids).cuda() + else: + input_ = batch.input_ids.reshape(len(batch), 1) + + batch_args = { + "batch_size": len(batch), + "max_len_in_batch": batch.nopad_max_len_in_batch, + "block_loc": batch.nopad_b_loc, + "start_loc": batch.nopad_b_start_loc, + "seq_len": batch.nopad_b_seq_len, + "cache_manager": batch.cache_manager, + "is_context_stage": is_prefill, + } + + infer_state = BatchInferState(**batch_args) + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + + setattr(model, "infer_state", infer_state) + output = self.model.forward(input_ids=input_) + logits = output.logits + # bsz, seq_len, vocab_size + prob_out = torch.softmax( + logits[ + :, + -1, + ], + dim=-1, + ).squeeze(1) + # prob_out: bsz, vocab_size + predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) + prob_out = torch.log(prob_out).detach().cpu().numpy() + predict_ids = predict_ids.detach().cpu().numpy() + # [ batch_size, 1 ] + + output_dict = {} + new_input_ids = [] + for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( + zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) + ): + next_token_id = int(next_token_id) + next_token_logprob = next_token_logprob[next_token_id] + # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") + all_input_ids.append(next_token_id) + # all_input_ids_tensor = None + new_input_ids.append(next_token_id) + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] += 1 + batch.out_token_id_counts[i][next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[r["request_id"]] = (int(next_token_id), metadata) + + batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() + batch.nopad_total_token_num += len(batch) + batch.nopad_max_len_in_batch += 1 # NOTE: we may repalce this + self.cache[batch.batch_id] = batch + return output_dict + + @torch.no_grad() + def _prefill_batch(self, batch_id): + return self.forward(batch_id, is_prefill=True) + + @torch.no_grad() + def _decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + # might want to create a sequence pool # add a single request/sequence/input text at a time and record its length # In other words, store the actual length of input tokens representing a single input text diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index c9e7aaae0..91bb96a1f 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -32,7 +32,7 @@ class MemoryManager: ): self.logger = logging.get_logger(__name__) self.available_size = size - self.past_key_values_length = 0 + self.max_len_in_batch = 0 self._init_mem_states(size, device) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) @@ -102,5 +102,5 @@ class MemoryManager: """free all memory by updating memory states""" self.available_size = len(self.mem_state) self.mem_state[:] = 1 - self.past_key_values_length = 0 + self.max_len_in_batch = 0 self.logger.info("freed all space of memory manager") diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 27a26caab..d84c567ea 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -133,17 +133,11 @@ class BloomInferenceForwards: assert hasattr(self, "infer_state") infer_state = self.infer_state - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - # if self.cache_manager.past_key_values_length > 0: - if infer_state.cache_manager.past_key_values_length > 0: - # update the past key values length in cache manager, - # NOTE use BatchInferState.past_key_values_length instead the one in cache manager - past_key_values_length = infer_state.cache_manager.past_key_values_length - seq_length_with_past = seq_length_with_past + past_key_values_length - # infer_state.cache_manager = self.cache_manager + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 if use_cache and seq_length != 1: # prefill stage @@ -160,21 +154,19 @@ class BloomInferenceForwards: infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) @@ -195,6 +187,7 @@ class BloomInferenceForwards: past_key_values_length=past_key_values_length, ) + infer_state.decode_layer_id = 0 for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -228,6 +221,7 @@ class BloomInferenceForwards: infer_state=infer_state, ) + infer_state.decode_layer_id += 1 hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) @@ -247,7 +241,7 @@ class BloomInferenceForwards: # and update these information in engine.generate after model foward called infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 - infer_state.decode_layer_id = 0 + infer_state.max_len_in_batch += 1 if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) @@ -453,9 +447,6 @@ class BloomInferenceForwards: mem_manager = infer_state.cache_manager layer_id = infer_state.decode_layer_id - if layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_length # += 1 - if infer_state.is_context_stage: # context process max_input_len = q_length @@ -506,15 +497,12 @@ class BloomInferenceForwards: b_loc, b_start_loc, b_seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, alibi, ) context_layer = output.view(batch_size, q_length, H * D_HEAD) - # update layer id - infer_state.decode_layer_id += 1 - # NOTE: always set present as none for now, instead of returning past key value to the next decoding, # we create the past key value pair from the cache manager present = None diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index b8274d3c6..69a92c4fe 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( from ._utils import copy_kv_to_mem_cache try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + HAS_LIGHTLLM_KERNEL = True except: print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") @@ -118,13 +121,12 @@ class ChatGLM2InferenceForwards: else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_key_values_length = 0 + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length seq_length_with_past = seq_length + past_key_values_length - infer_state.seq_length_with_past = seq_length_with_past # prefill stage at first if use_cache and seq_length != 1: @@ -272,7 +274,6 @@ class ChatGLM2InferenceForwards: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 infer_state.max_len_in_batch += 1 - infer_state.cache_manager.past_key_values_length += seq_length if not return_dict: return tuple( @@ -487,7 +488,7 @@ class ChatGLM2InferenceForwards: attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), infer_state.start_loc, infer_state.seq_len, - infer_state.seq_length_with_past, + infer_state.max_len_in_batch, ) else: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a3937f6f1..a17b901dc 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -74,12 +74,11 @@ class LlamaInferenceForwards: output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - batch_size = input_ids.shape[0] # input_ids.shape[0] - infer_state = self.infer_state return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -90,15 +89,10 @@ class LlamaInferenceForwards: else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -118,23 +112,23 @@ class LlamaInferenceForwards: infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) + position_ids = position_ids.repeat(batch_size, 1) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -146,11 +140,12 @@ class LlamaInferenceForwards: infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item() + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -158,7 +153,7 @@ class LlamaInferenceForwards: # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -173,7 +168,6 @@ class LlamaInferenceForwards: next_decoder_cache = () if use_cache else None infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None # NOTE: modify here for passing args to decoder layer @@ -197,8 +191,9 @@ class LlamaInferenceForwards: # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -224,7 +219,6 @@ class LlamaInferenceForwards: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -280,11 +274,8 @@ class LlamaInferenceForwards: # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin - # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) @@ -295,7 +286,6 @@ class LlamaInferenceForwards: if infer_state.is_context_stage: # first token generation - # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -304,7 +294,6 @@ class LlamaInferenceForwards: infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) if self.num_key_value_groups == 1: @@ -315,7 +304,7 @@ class LlamaInferenceForwards: attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: lightllm_llama2_context_attention_fwd( @@ -325,7 +314,7 @@ class LlamaInferenceForwards: attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: if infer_state.decode_is_contiguous: @@ -363,7 +352,7 @@ class LlamaInferenceForwards: infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: Llama2TokenAttentionForwards.token_attn( @@ -374,7 +363,7 @@ class LlamaInferenceForwards: infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, infer_state.other_kv_index, ) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 1fe292289..20da71d39 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -2,7 +2,6 @@ try: import triton HAS_TRITON = True - except ImportError: HAS_TRITON = False print("Triton is not installed. Please install Triton to use Triton kernels.") diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 0ce6b09e5..b8e6ab1d0 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -10,7 +10,6 @@ except ImportError: print("please install triton from https://github.com/openai/triton") if HAS_TRITON: - # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( @@ -53,7 +52,6 @@ if HAS_TRITON: assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( k_ptr, dest_index_ptr, diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 467f83610..f54b13c7e 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,4 +18,6 @@ SentencePiece ninja flash_attn==2.0.5 datasets +pydantic +ray #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 19cb7a154..095617d76 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,6 +11,8 @@ ninja torch>=1.12 safetensors einops +pydantic +ray sentencepiece google protobuf diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 041de6b90..473064270 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,8 +27,10 @@ if HAS_LLAMA: # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- - input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + input_ids = torch.Tensor( + [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + ).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 399b70e14..f9f7670c4 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -52,7 +52,6 @@ def run_chatglm2_test(test_config): "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), } outputs = infer_engine.generate(input_tokens, **generate_kwargs) - assert outputs is not None diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml new file mode 100644 index 000000000..0ac778a3c --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -0,0 +1,14 @@ +engine_config: + model: MODEL_PATH + tensor_parallel_size: 1 + max_batch_size: 2 + max_input_len: 1024 + max_output_len: 512 +# config for app router deployment +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig. +router_config: + max_total_token_num: 4096 + batch_max_tokens: 4096 + disable_log_stats: False + log_stats_interval: 10 + model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py new file mode 100644 index 000000000..512aa7430 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -0,0 +1,61 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.async_engine import Async_Engine +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_async_engine(path: str): + if not os.path.exists(path): + return + + config = RayInitConfig.from_yaml_path(path) + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + + prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10" + sampling_params = SamplingParams() + asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) + + +async def get_result(engine, prompt, sampling_params): + request_id = str(uuid.uuid4().hex) + results = engine.generate(request_id, prompt, sampling_params) + async for result in results: + # print(result) + assert result is not None + + +async def asy_for_loop_test(config, prompt, sampling_params): + router_config = config.router_config_data + engine_config = config.engine_config_data + engine = Async_Engine(router_config=router_config, engine_config=engine_config) + for i in range(10): + print("in for loop", i) + await get_result(engine, prompt, sampling_params) + + +def check_async_engine(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_async_engine(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_async_engine(): + spawn(check_async_engine, 1) + + +if __name__ == "__main__": + test_async_engine() diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py new file mode 100644 index 000000000..78df0d304 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -0,0 +1,95 @@ +import pytest +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import DynamicBatchManager +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +BATCH_SIZE = 2 +MAX_INPUT_LEN = 48 +MAX_OUTPUT_LEN = 256 + + +def run(): + sampling_params = SamplingParams() + + req1 = Req(0, [1], sampling_params) + req2 = Req(1, [2], sampling_params) + req3 = Req(2, [3], sampling_params) + # req 1-3 are initiliazed as token forward requests + req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + + # init model and tp engine + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + dynamic_batch_manager = DynamicBatchManager( + tp_engine=infer_engine, + max_total_token_num=640, + batch_max_tokens=608, + eos_id=0, + log_stats=False, + log_stats_interval=10, + waiting_req_list=waiting_list, + model="llama", + ) + before_add = len(dynamic_batch_manager.req_queue) + + # test add req function + dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params) + assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 + + # test abort function + dynamic_batch_manager.abort(req4.request_id) + assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True + + # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested + batch = dynamic_batch_manager.req_queue.generate_new_batch() + assert len(batch) == 2 + + dynamic_batch_manager._init_batch(batch) + assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None + + batch.reqs[0].has_generate_finished = True + # filter one finished + batch.filter_finished() + dynamic_batch_manager._filter_batch(batch) + assert len(dynamic_batch_manager.engine.cache) == 1 + + # test merge batch + new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch) + assert len(new_batch) == 1 + dynamic_batch_manager._init_batch(new_batch) + dynamic_batch_manager._merge_batch(batch, new_batch) + + assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2 + + +def check_dynamic_batching_manager(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching_manager(): + spawn(check_dynamic_batching_manager, 1) + + +if __name__ == "__main__": + test_dynamic_batching_manager() diff --git a/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py new file mode 100644 index 000000000..9925a80b6 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass + +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +MAX_BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@dataclass +class args: + max_total_token_num: int + batch_max_tokens: int + model: str + eos_id: int + disable_log_stats: bool + log_stats_interval: int + + +def run(): + arg = args( + max_total_token_num=42, + model="llama", + batch_max_tokens=42, + eos_id=0, + disable_log_stats=False, + log_stats_interval=10, + ) + sampling_params = SamplingParams() + + req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) + req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) + req3 = Req(2, [0, 0, 10, 10, 10], sampling_params) + req4 = Req(3, [0, 0, 10, 10, 10], sampling_params) + + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + waiting_list.append(req4) + + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) + + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + + ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params) + for result in ans_gen: + assert result is not None + + +def check_dynamic_forward(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching(): + spawn(check_dynamic_forward, TP_SIZE) + + +if __name__ == "__main__": + test_dynamic_batching() diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py new file mode 100644 index 000000000..a840407d5 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -0,0 +1,66 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_ray_dist(path: str): + if not os.path.exists(path): + return + config = RayInitConfig.from_yaml_path(path) + router_config = config.router_config_data + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + driver = Driver(router_config=router_config, engine_config=engine_config) + prompt = "Introduce some landmarks in Beijing" + + request_id = str(uuid.uuid4().hex) + sampling_params = SamplingParams() + print("sampling_params: ", sampling_params) + + async def get_result(request_id, prompt, sampling_params): + return await driver.async_generate(request_id, prompt, sampling_params) + + for test_async in [True, False]: + if test_async: + print("test_async: ", test_async) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) + assert result is not None + print("result: ", result) + else: + print("test_async: ", test_async) + result = driver.generate(request_id, prompt, sampling_params) + assert result is not None + print("result: ", result) + + is_running = None + is_running = driver.is_running() + assert is_running is not None + print("is_running: ", is_running) + + +def check_ray_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_ray_dist(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_ray_dist(): + spawn(check_ray_dist, 1) + + +if __name__ == "__main__": + test_ray_dist()