From 69cd7e069d5705c7e431b301ac14924711c74e41 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:47:36 +0800 Subject: [PATCH 1/7] [Inference] ADD async and sync Api server using FastAPI (#5396) * add api server * fix * add * add completion service and fix bug * add generation config * revise shardformer * fix bugs * add docstrings and fix some bugs * fix bugs and add choices for prompt template --- colossalai/inference/batch_bucket.py | 3 + colossalai/inference/config.py | 19 +- colossalai/inference/core/async_engine.py | 318 ++++++++++++++++++ colossalai/inference/core/engine.py | 24 +- colossalai/inference/core/request_handler.py | 34 +- colossalai/inference/server/__init__.py | 0 colossalai/inference/server/api_server.py | 200 +++++++++++ .../inference/server/completion_service.py | 35 ++ colossalai/inference/server/utils.py | 16 + colossalai/inference/struct.py | 1 + colossalai/shardformer/shard/shardformer.py | 7 +- .../test_async_engine/test_async_engine.py | 80 +++++ .../test_async_engine/test_request_tracker.py | 77 +++++ 13 files changed, 789 insertions(+), 25 deletions(-) create mode 100644 colossalai/inference/core/async_engine.py create mode 100644 colossalai/inference/server/__init__.py create mode 100644 colossalai/inference/server/api_server.py create mode 100644 colossalai/inference/server/completion_service.py create mode 100644 colossalai/inference/server/utils.py create mode 100644 tests/test_infer/test_async_engine/test_async_engine.py create mode 100644 tests/test_infer/test_async_engine/test_request_tracker.py diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 726dfd614..8cc9eebaa 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -62,6 +62,9 @@ class BatchBucket: def current_batch_size(self): return self._current_batch_size + def __len__(self): + return self._current_batch_size + @property def available_batch_size(self): return self.max_batch_size - self._current_batch_size diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index a68400fb0..421c6b589 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -1,10 +1,10 @@ """ Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. """ - +import dataclasses import logging from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch import torch.distributed as dist @@ -214,3 +214,18 @@ class InferenceConfig: meta_config[type] = getattr(model_config, type) return GenerationConfig.from_dict(meta_config) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + inference_config_args = {} + for attr in attrs: + if attr in config_dict: + inference_config_args[attr] = config_dict[attr] + else: + inference_config_args[attr] = getattr(cls, attr) + + # Set the attributes from the parsed arguments. 
+ inference_config = cls(**inference_config_args) + return inference_config diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py new file mode 100644 index 000000000..5be36fada --- /dev/null +++ b/colossalai/inference/core/async_engine.py @@ -0,0 +1,318 @@ +import asyncio +from functools import partial +from logging import Logger +from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type + +from colossalai.inference.core.engine import InferenceEngine + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None: + msg = "Task finished unexpectedly. This should never happen! " + try: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc + raise AsyncEngineDeadError(msg) + except Exception as exc: + request_tracker.propagate_exception(exc) + raise exc + + +class AsyncStream: + """A stream of Output for a request that can be + iterated over asynchronously.""" + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._queue = asyncio.Queue() + self._finished = False + + def put(self, item) -> None: + if self._finished: + return + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopIteration) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self): + result = await self._queue.get() + if result is StopIteration: + raise StopAsyncIteration + elif isinstance(result, Exception): + raise result + return result + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._finished_requests: asyncio.Queue[int] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._request_streams + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None: + """ + Propagate an exception to request streams (all if request_id is None). + """ + if request_id is not None: + self._request_streams[request_id].put(exc) + else: + for stream in self._request_streams.values(): + stream.put(exc) + + def process_finished_request(self, finished_request) -> None: + """Process a finished request from the engine.""" + request_id = finished_request.request_id + + self._request_streams[request_id].put(finished_request) + self.abort_request(request_id) + + def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream: + """ + Add a request to be sent to the engine on the next background + loop iteration. 
+ """ + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + stream = AsyncStream(request_id) + self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs})) + + self.new_requests_event.set() + + return stream + + def abort_request(self, request_id: int, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + Logger.info(f"Aborted request {request_id}.") + + self._finished_requests.put_nowait(request_id) + + if request_id not in self._request_streams or self._request_streams[request_id].finished: + # The request has already finished or been aborted. + return + + self._request_streams[request_id].finish() + + def get_new_requests(self): + """ + Get new requests from http server. + """ + new_requests: List[Dict] = [] + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + self.new_requests_event.clear() + + return new_requests + + def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[int] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) + self._request_streams.pop(request_id, None) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + if stream.request_id in finished_requests: + # The request has already been aborted. + stream.finish() + continue + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + self.new_requests_event.clear() + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + +class _AsyncInferenceEngine(InferenceEngine): + """ + Async methods for Inference Engine. + """ + + async def async_step(self) -> List[str]: + """ + The async version of Engine.step() + Performs one decoding iteration and returns newly generated results. + + It first schedules the sequences to be executed in the next iteration. + Then, it executes the model and updates the scheduler with the model + outputs. Finally, it decodes the sequences and returns the newly + generated results. + """ + batch = self.request_handler.schedule() + loop = asyncio.get_running_loop() + + # Use run_in_executor to asyncally run the sync method model.forward(). + logits = await loop.run_in_executor( + None, + self.model, + batch, + self.k_cache, + self.v_cache, + ) + + if self.inference_config.pad_input: + logits = logits[:, -1, :] + self.request_handler.search_tokens(self.generation_config, logits) + # Return: List[Sequence] + finished_sequences = self.request_handler.update() + + return finished_sequences, self.request_handler.current_requests_in_batch() > 0 + + +class AsyncInferenceEngine: + """An asynchronous wrapper for LLMEngine. + + This class is used to wrap the InferenceEngine class to make it asynchronous. + It uses asyncio to create a background loop that keeps processing incoming + requests. The LLMEngine is kicked by the generate method when there are + requests in the waiting queue. The generate method yields the outputs + from the InferenceEngine to the caller. 
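+
+    A usage sketch (assumes a model, tokenizer and InferenceConfig prepared as
+    in colossalai/inference/server/api_server.py; the variable names here are
+    illustrative only):
+
+        engine = AsyncInferenceEngine(
+            start_engine_loop=True,
+            model=model,
+            tokenizer=tokenizer,
+            inference_config=inference_config,
+        )
+        async for output in engine.generate(request_id, prompt):
+            ...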
+ """ + + _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine + + def __init__(self, start_engine_loop: bool = True, **kwargs): + self.engine = self._init_engine(**kwargs) + self.background_loop = None + # reference to the unshielded loop + self._background_loop_unshielded = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + @property + def background_loop_status(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.background_loop_status: + raise RuntimeError("Existing loop is running") + + self._request_tracker.init_event() + + self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop()) + self._background_loop_unshielded.add_done_callback( + partial(_raise_exception_on_finish, request_tracker=self._request_tracker) + ) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def _init_engine(self, **kwargs): + return self._engine_class(**kwargs) + + async def step(self): + """ + Run engine to process requests + + Returns True if there are in-progress requests. + """ + new_requests = self._request_tracker.get_new_requests() + for new_request in new_requests: + self.engine.add_single_request(**new_request) + newly_finished_seqs, has_running_requests = await self.engine.async_step() + for seq in newly_finished_seqs: + self._request_tracker.process_finished_request(seq) + + return has_running_requests + + async def _engine_abort(self, request_ids: Iterable[int]): + self.engine.abort_request(request_ids) + + async def abort(self, request_id: int): + """ + Abort a single request + """ + if not self.background_loop_status: + raise RuntimeError("Background loop is not running or launched correctly.") + return self._abort(request_id) + + def _abort(self, request_id: int): + self._request_tracker.abort_request(request_id) + + async def run_engine_loop(self): + processing_requests = False + while True: + if not processing_requests: + await self._request_tracker.wait_for_new_requests() + processing_requests = await self.step() + await asyncio.sleep(0) + + async def add_request( + self, + request_id: int, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + ) -> AsyncStream: + """ + Add a request to the background tracker(waitting queue), start the background loop if needed. + """ + if not self.background_loop_status: + if self.start_engine_loop: + self.start_background_loop() + else: + raise RuntimeError("Background loop is not running.") + stream = self._request_tracker.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + ) + return stream + + async def generate( + self, + request_id: int, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + ) -> AsyncIterator[str]: + """ + Generate output from a request. It receives the request from http server, adds it into the + waitting queue of Async Engine and streams the output sequence. + + """ + try: + stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids) + async for request_output in stream: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the + # request. 
+ self._abort(request_id) + raise e diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 04eb620c5..eb5a825d2 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,6 +1,6 @@ import time from itertools import count -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Iterable import numpy as np import torch @@ -507,9 +507,9 @@ class InferenceEngine: def generate( self, - prompts: List[str] = None, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, - request_ids: List[int] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, ) -> List[str]: @@ -527,6 +527,11 @@ class InferenceEngine: List[str]: Inference result returned by one generation. """ with torch.inference_mode(): + + if isinstance(prompts, str) and isinstance(request_ids, int): + prompts = [prompts] + request_ids = [request_ids] + if prompts is not None or prompts_token_ids is not None: gen_config_dict = generation_config.to_dict() if generation_config is not None else {} self.add_request( @@ -535,7 +540,7 @@ class InferenceEngine: prompts_token_ids=prompts_token_ids, **gen_config_dict, ) - + output_seqs_list = [] total_tokens_list = [] @@ -580,13 +585,13 @@ class InferenceEngine: if isinstance(prompts, (list, tuple)): return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] elif isinstance(prompts, str): - return self.inference_config.rompt_template.format(input_text=prompts) + return self.inference_config.prompt_template.format(input_text=prompts) else: raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") def add_request( self, - request_ids: List[int] = None, + request_ids: Union[List[int], int] = None, prompts: List[str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, **kwargs, @@ -601,6 +606,7 @@ class InferenceEngine: """ # apply the prompt template to the input prompts + if self.has_prompt_template and prompts is not None: prompts = self.format_prompt(prompts) @@ -614,6 +620,7 @@ class InferenceEngine: prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ "input_ids" ] + print(prompts_token_ids) if isinstance(prompts_token_ids, list): pass @@ -632,8 +639,6 @@ class InferenceEngine: for i in range(prompts_num): if request_ids: - if not isinstance(request_ids, list): - request_ids = [request_ids] assert isinstance( request_ids[0], int ), f"The request_id type must be int, but got {type(request_ids[0])}" @@ -734,6 +739,9 @@ class InferenceEngine: next_tokens = self.request_handler.search_tokens(self.generation_config, logits) self.request_handler.append_next_tokens(next_tokens) + print("in step", logits) + + self.request_handler.search_tokens(self.generation_config, logits) finished_sequences = self.request_handler.update() return finished_sequences diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 10180ff2f..6837a80c5 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -263,24 +263,27 @@ class RequestHandler: ), f"Sequence {req.request_id} exceeds input length limit" self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req) - 
def abort_sequence(self, request_id: str): + def abort_sequence(self, request_id: int): """ Abort the request. """ - seq, priority = self._find_sequence(request_id) - if seq.status == RequestStatus.WAITING: - seq.mark_aborted() - self.waiting_list[priority].remove(seq) - elif seq.status.is_running(): - self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table) - self.running_list.remove(seq) - else: - try: - self.done_list.remove(seq) - except: - return + result = self._find_sequence(request_id) + if result is not None: + seq, priority = result + if seq.status == RequestStatus.WAITING: + seq.mark_aborted() + self.waiting_list[priority].remove(seq) + elif seq.status.is_running(): + self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table) + self.running_list.remove(seq) + else: + try: + self.done_list.remove(seq) + except: + return + return - def _find_sequence(self, request_id: str) -> Sequence: + def _find_sequence(self, request_id: int) -> Sequence: """ Find the request by request_id. """ @@ -324,6 +327,9 @@ class RequestHandler: def check_unfinished_seqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() + def current_requests_in_batch(self) -> int: + return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size + def search_tokens(self, generation_config: GenerationConfig, logits): """ Sample tokens for finished requests. diff --git a/colossalai/inference/server/__init__.py b/colossalai/inference/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py new file mode 100644 index 000000000..c182c5160 --- /dev/null +++ b/colossalai/inference/server/api_server.py @@ -0,0 +1,200 @@ +""" +Doc: + Feature: + - FastAPI based http server for Colossal-Inference + - Completion Service Supported + Usage: (for local user) + - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model` + - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api + - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hello, who are you? ","stream":"False"}'` +""" + + +import argparse +import json + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from transformers import AutoModelForCausalLM, AutoTokenizer + +from colossalai.inference.config import InferenceConfig +from colossalai.inference.server.completion_service import CompletionServing +from colossalai.inference.server.utils import id_generator + +from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None +supported_models_dict = {"Llama_Models": ("llama2-7b",)} +prompt_template_choices = ["llama", "vicuna"] + + +@app.get("/v0/models") +def get_available_models() -> Response: + return JSONResponse(supported_models_dict) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + A request should be a JSON object with the following fields: + - prompts: the prompts to use for the generation. + - stream: whether to stream the results or not. 
+ - other fields: + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", None) + + request_id = id_generator() + generation_config = get_generation_config(request_dict) + results = engine.generate(request_id, prompt, generation_config=generation_config) + + # Streaming case + def stream_results(): + for request_output in results: + ret = {"text": request_output} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + for request_output in results: + if request.is_disconnected(): + # Abort the request if the client disconnects. + engine.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + ret = {"text": final_output} + return JSONResponse(ret) + + +@app.post("/v1/completion") +async def create_completion(request: Request): + request_dict = await request.json() + generation_config = get_generation_config(request_dict) + generator = await completion_serving.create_completion(request, generation_config) + output = tokenizer.decode(generator.output_token_id) + ret = {"request_id": generator.request_id, "text": output} + return ret + + +def get_generation_config(request): + generation_config = async_engine.engine.generation_config + for arg in request: + if hasattr(generation_config, arg): + generation_config[arg] = request[arg] + return generation_config + + +def add_engine_config(parser): + parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use") + + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. If unspecified, " "will be automatically derived from the model.", + ) + # Parallel arguments + parser.add_argument( + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU", + ) + + parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages") + + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas") + + # KV cache arguments + parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size") + + parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size") + + # generation arguments + parser.add_argument( + "--prompt_template", + choices=prompt_template_choices, + default=None, + help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.", + ) + + # Quantization settings. + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", "gptq", "squeezellm", None], + default=None, + help="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Always use eager-mode PyTorch. 
If False, " + "will use eager mode and CUDA graph in hybrid " + "for maximal performance and flexibility.", + ) + return parser + + +def parse_args(): + parser = argparse.ArgumentParser(description="Colossal-Inference API server.") + + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument( + "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy" + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="The model name used in the API. If not " + "specified, the model name will be the same as " + "the huggingface name.", + ) + parser = add_engine_config(parser) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + inference_config = InferenceConfig.from_dict(vars(args)) + model = AutoModelForCausalLM.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model) + async_engine = AsyncInferenceEngine( + start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config + ) + engine = async_engine.engine + completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__) + + app.root_path = args.root_path + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/colossalai/inference/server/completion_service.py b/colossalai/inference/server/completion_service.py new file mode 100644 index 000000000..bb2160009 --- /dev/null +++ b/colossalai/inference/server/completion_service.py @@ -0,0 +1,35 @@ +import asyncio + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + +from .utils import id_generator + + +class CompletionServing: + def __init__(self, engine: AsyncInferenceEngine, served_model: str): + self.engine = engine + self.served_model = served_model + + try: + asyncio.get_running_loop() + except RuntimeError: + pass + + async def create_completion(self, request, generation_config): + request_dict = await request.json() + request_id = id_generator() + prompt = request_dict.pop("prompt") + + # it is not a intuitive way + self.engine.engine.generation_config = generation_config + result_generator = self.engine.generate(request_id, prompt=prompt) + + final_res = None + async for res in result_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(request_id) + return {"error_msg": "Client disconnected"} + final_res = res + + return final_res diff --git a/colossalai/inference/server/utils.py b/colossalai/inference/server/utils.py new file mode 100644 index 000000000..c10826f73 --- /dev/null +++ b/colossalai/inference/server/utils.py @@ -0,0 +1,16 @@ +# make it singleton +class NumericIDGenerator: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NumericIDGenerator, cls).__new__(cls) + cls._instance.current_id = 0 + return cls._instance + + def __call__(self): + self.current_id += 1 + return self.current_id + + +id_generator = NumericIDGenerator() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index db4820f51..334a39b4e 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -164,6 +164,7 @@ class Sequence: return ( f"(request_id={self.request_id}, " f"prompt={self.prompt}, " + f"output_token_id={self.output_token_id}," f"status={self.status.name}, " f"sample_params={self.sample_params}, " f"input_len={self.input_len}," diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index b3991c4f0..b54c58273 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,6 +1,7 @@ import os from typing import Dict, List, Tuple +import torch.distributed as dist import torch.nn as nn from torch import Tensor @@ -36,7 +37,11 @@ class ShardFormer: """ def __init__(self, shard_config: ShardConfig): - self.coordinator = DistCoordinator() + self.is_distributed = dist.is_initialized() + if self.is_distributed: + self.coordinator = DistCoordinator() + else: + self.coordinator = None self.shard_config = shard_config def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: diff --git a/tests/test_infer/test_async_engine/test_async_engine.py b/tests/test_infer/test_async_engine/test_async_engine.py new file mode 100644 index 000000000..ebca11c72 --- /dev/null +++ b/tests/test_infer/test_async_engine/test_async_engine.py @@ -0,0 +1,80 @@ +import asyncio +from dataclasses import dataclass + +import pytest + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + + +@dataclass +class SequenceTpye: + request_id: int + + +class MockEngine: + def __init__(self): + self.step_calls = 0 + self.add_request_calls = 0 + self.abort_request_calls = 0 + self.request_id = None + + async def async_step(self): + self.step_calls += 1 + return [SequenceTpye(request_id=self.request_id)] if self.request_id else [] + + def generate(self, request_id): + self.request_id = request_id + + def stop_generating(self): + self.request_id = None + + def add_request(self, **kwargs): + del kwargs # Unused + self.add_request_calls += 1 + + def abort_request(self, request_id): + del request_id # Unused + self.abort_request_calls += 1 + + +class MockAsyncLLMEngine(AsyncInferenceEngine): + def _init_engine(self, *args, **kwargs): + return MockEngine() + + +@pytest.mark.asyncio +async def test_new_requests_event(): + engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False) + engine.start_background_loop() + await asyncio.sleep(0.01) + assert engine.engine.step_calls == 0 + + await engine.add_request(1, "", None) + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 1 + assert engine.engine.step_calls == 1 + + await engine.add_request(2, "", None) + engine.engine.generate(2) 
+ await asyncio.sleep(0) + assert engine.engine.add_request_calls == 2 + assert engine.engine.step_calls == 2 + await asyncio.sleep(0) + assert engine.engine.step_calls == 3 + engine.engine.stop_generating() + await asyncio.sleep(0) + assert engine.engine.step_calls == 4 + await asyncio.sleep(0) + assert engine.engine.step_calls == 4 + + await engine.add_request(3, "", None) + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 3 + assert engine.engine.step_calls == 5 + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 3 + assert engine.engine.step_calls == 5 + + +if __name__ == "__main__": + test_new_requests_event() diff --git a/tests/test_infer/test_async_engine/test_request_tracker.py b/tests/test_infer/test_async_engine/test_request_tracker.py new file mode 100644 index 000000000..9a797a862 --- /dev/null +++ b/tests/test_infer/test_async_engine/test_request_tracker.py @@ -0,0 +1,77 @@ +import pytest + +from colossalai.inference.core.async_engine import RequestTracker +from colossalai.inference.struct import Sequence + + +class SampleEvent: + def __init__(self): + self.flag = False + + def set(self): + self.flag = True + + def clear(self): + self.flag = False + + +def test_request_tracker(): + tracker = RequestTracker() + tracker.new_requests_event = SampleEvent() + stream_1 = tracker.add_request(1) + assert tracker.new_requests_event.flag + new, finished = tracker.get_new_and_finished_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 1 + assert new[0]["request_id"] == 1 + assert not finished + assert not stream_1.finished + + stream_2 = tracker.add_request(2) + stream_3 = tracker.add_request(3) + assert tracker.new_requests_event.flag + new, finished = tracker.get_new_and_finished_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 2 + assert new[0]["request_id"] == 2 + assert new[1]["request_id"] == 3 + assert not finished + assert not stream_2.finished + assert not stream_3.finished + + # request_ids must be unique + with pytest.raises(KeyError): + tracker.add_request(1) + assert not tracker.new_requests_event.flag + + tracker.abort_request(1) + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert 1 in finished + assert not new + assert stream_1.finished + + stream_4 = tracker.add_request(4) + tracker.abort_request(4) + assert tracker.new_requests_event.flag + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert 4 in finished + assert not new + assert stream_4.finished + + stream_5 = tracker.add_request(5) + assert tracker.new_requests_event.flag + tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) + new, finished = tracker.get_new_and_finished_requests() + assert not tracker.new_requests_event.flag + assert len(finished) == 1 + assert 2 in finished + assert len(new) == 1 + assert new[0]["request_id"] == 5 + assert stream_2.finished + assert not stream_5.finished + + +if __name__ == "__main__": + test_request_tracker() From de378cd2abd77b464786dc5f8298c9edbf023fbc Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:06:05 +0800 Subject: [PATCH 2/7] [Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432) * finish online test and add examples * fix test_contionus_batching * fix some bugs * fix bash * fix * fix inference * finish revision * fix typos * revision --- 
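Note (not part of the commit message): a minimal client sketch for the streaming
"/generate" route exercised by this patch. It assumes the server from
colossalai/inference/server/api_server.py is running locally on port 8000, that
streamed chunks are NUL("\0")-terminated JSON objects with a "text" field as
produced by stream_results(), and that the `requests` package is available on
the client side.

    import json

    import requests

    resp = requests.post(
        "http://127.0.0.1:8000/generate",
        json={"prompt": "hello, who are you? ", "stream": "True"},
        stream=True,
    )
    buffer = b""
    for chunk in resp.iter_content(chunk_size=None):
        buffer += chunk
        # Each streamed record is a JSON object terminated by a NUL byte.
        while b"\0" in buffer:
            record, buffer = buffer.split(b"\0", 1)
            if record:
                print(json.loads(record)["text"])
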
colossalai/inference/core/async_engine.py | 125 +++++++----------- colossalai/inference/core/engine.py | 6 +- colossalai/inference/core/request_handler.py | 1 + colossalai/inference/server/api_server.py | 16 ++- .../inference/server/completion_service.py | 13 +- colossalai/inference/struct.py | 2 + .../kernel/triton/no_pad_rotary_embedding.py | 2 + examples/inference/client/locustfile.py | 30 +++++ examples/inference/client/run_locust.sh | 24 ++++ tests/test_infer/test_continuous_batching.py | 89 +++++++++++++ 10 files changed, 214 insertions(+), 94 deletions(-) create mode 100644 examples/inference/client/locustfile.py create mode 100644 examples/inference/client/run_locust.sh create mode 100644 tests/test_infer/test_continuous_batching.py diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py index 5be36fada..e23d0b90f 100644 --- a/colossalai/inference/core/async_engine.py +++ b/colossalai/inference/core/async_engine.py @@ -1,13 +1,13 @@ import asyncio +import logging from functools import partial -from logging import Logger -from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type +from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type from colossalai.inference.core.engine import InferenceEngine - -class AsyncEngineDeadError(RuntimeError): - pass +# CLI logger +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger("colossalai-inference") def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None: @@ -18,54 +18,45 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac except asyncio.CancelledError: return except Exception as exc: - raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc - raise AsyncEngineDeadError(msg) + raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc + raise RuntimeError(msg) except Exception as exc: request_tracker.propagate_exception(exc) raise exc -class AsyncStream: +class RequstStream: """A stream of Output for a request that can be iterated over asynchronously.""" - def __init__(self, request_id: str) -> None: + def __init__(self, request_id: int) -> None: self.request_id = request_id - self._queue = asyncio.Queue() - self._finished = False + self._future = asyncio.Future() - def put(self, item) -> None: - if self._finished: - return - self._queue.put_nowait(item) + def set_result(self, result) -> None: + """Set final result and signal taht it's ready""" + if not self._future.done(): + self._future.set_result(result) - def finish(self) -> None: - self._queue.put_nowait(StopIteration) - self._finished = True + async def get_result(self): + """Wait for the result to be set and return it.""" + return await self._future @property def finished(self) -> bool: - return self._finished - - def __aiter__(self): - return self - - async def __anext__(self): - result = await self._queue.get() - if result is StopIteration: - raise StopAsyncIteration - elif isinstance(result, Exception): - raise result - return result + """Check if the stream has finished by checking if the future is done.""" + return self._future.done() -class RequestTracker: - """Synchronous abstraction for tracking requests.""" +class Tracer: + """ + Recording new requests and finished requests. 
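+
+    A request registered through add_request() is backed by a RequstStream
+    (an asyncio.Future under the hood); get_new_requests() hands the pending
+    requests over to the engine loop, and process_finished_request() /
+    abort_request() resolve the future so that a caller awaiting
+    RequstStream.get_result() is woken up.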
+ """ def __init__(self) -> None: - self._request_streams: Dict[str, AsyncStream] = {} + self._request_streams: Dict[int, RequstStream] = {} self._finished_requests: asyncio.Queue[int] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue() self.new_requests_event = None def __contains__(self, item): @@ -79,19 +70,21 @@ class RequestTracker: Propagate an exception to request streams (all if request_id is None). """ if request_id is not None: - self._request_streams[request_id].put(exc) + self._request_streams[request_id].set_result(exc) else: for stream in self._request_streams.values(): - stream.put(exc) + stream.set_result(exc) def process_finished_request(self, finished_request) -> None: """Process a finished request from the engine.""" request_id = finished_request.request_id - - self._request_streams[request_id].put(finished_request) + try: + self._request_streams[request_id].set_result(finished_request) + except: + raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check") self.abort_request(request_id) - def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream: + def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream: """ Add a request to be sent to the engine on the next background loop iteration. @@ -99,7 +92,7 @@ class RequestTracker: if request_id in self._request_streams: raise KeyError(f"Request {request_id} already exists.") - stream = AsyncStream(request_id) + stream = RequstStream(request_id) self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs})) self.new_requests_event.set() @@ -109,7 +102,7 @@ class RequestTracker: def abort_request(self, request_id: int, *, verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: - Logger.info(f"Aborted request {request_id}.") + logger.info(f"Aborted request {request_id}.") self._finished_requests.put_nowait(request_id) @@ -117,7 +110,7 @@ class RequestTracker: # The request has already finished or been aborted. return - self._request_streams[request_id].finish() + self._request_streams[request_id].set_result(None) def get_new_requests(self): """ @@ -134,30 +127,6 @@ class RequestTracker: return new_requests - def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]: - """Get the new requests and finished requests to be - sent to the engine.""" - new_requests: List[Dict] = [] - finished_requests: Set[int] = set() - - while not self._finished_requests.empty(): - request_id = self._finished_requests.get_nowait() - finished_requests.add(request_id) - self._request_streams.pop(request_id, None) - - while not self._new_requests.empty(): - stream, new_request = self._new_requests.get_nowait() - if stream.request_id in finished_requests: - # The request has already been aborted. 
- stream.finish() - continue - self._request_streams[stream.request_id] = stream - new_requests.append(new_request) - - self.new_requests_event.clear() - - return new_requests, finished_requests - async def wait_for_new_requests(self): await self.new_requests_event.wait() @@ -194,6 +163,8 @@ class _AsyncInferenceEngine(InferenceEngine): self.request_handler.search_tokens(self.generation_config, logits) # Return: List[Sequence] finished_sequences = self.request_handler.update() + for sequence in finished_sequences: + sequence.output = self.tokenizer.decode(sequence.output_token_id) return finished_sequences, self.request_handler.current_requests_in_batch() > 0 @@ -216,7 +187,7 @@ class AsyncInferenceEngine: # reference to the unshielded loop self._background_loop_unshielded = None self.start_engine_loop = start_engine_loop - self._request_tracker = RequestTracker() + self._request_tracer = Tracer() @property def background_loop_status(self): @@ -226,11 +197,11 @@ class AsyncInferenceEngine: if self.background_loop_status: raise RuntimeError("Existing loop is running") - self._request_tracker.init_event() + self._request_tracer.init_event() self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop()) self._background_loop_unshielded.add_done_callback( - partial(_raise_exception_on_finish, request_tracker=self._request_tracker) + partial(_raise_exception_on_finish, request_tracker=self._request_tracer) ) self.background_loop = asyncio.shield(self._background_loop_unshielded) @@ -243,12 +214,13 @@ class AsyncInferenceEngine: Returns True if there are in-progress requests. """ - new_requests = self._request_tracker.get_new_requests() + new_requests = self._request_tracer.get_new_requests() for new_request in new_requests: self.engine.add_single_request(**new_request) newly_finished_seqs, has_running_requests = await self.engine.async_step() + for seq in newly_finished_seqs: - self._request_tracker.process_finished_request(seq) + self._request_tracer.process_finished_request(seq) return has_running_requests @@ -264,13 +236,13 @@ class AsyncInferenceEngine: return self._abort(request_id) def _abort(self, request_id: int): - self._request_tracker.abort_request(request_id) + self._request_tracer.abort_request(request_id) async def run_engine_loop(self): processing_requests = False while True: if not processing_requests: - await self._request_tracker.wait_for_new_requests() + await self._request_tracer.wait_for_new_requests() processing_requests = await self.step() await asyncio.sleep(0) @@ -279,7 +251,7 @@ class AsyncInferenceEngine: request_id: int, prompt: Optional[str], prompt_token_ids: Optional[List[int]] = None, - ) -> AsyncStream: + ) -> RequstStream: """ Add a request to the background tracker(waitting queue), start the background loop if needed. 
""" @@ -288,7 +260,7 @@ class AsyncInferenceEngine: self.start_background_loop() else: raise RuntimeError("Background loop is not running.") - stream = self._request_tracker.add_request( + stream = self._request_tracer.add_request( request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, @@ -308,8 +280,7 @@ class AsyncInferenceEngine: """ try: stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids) - async for request_output in stream: - yield request_output + return await stream.get_result() except (Exception, asyncio.CancelledError) as e: # If there is an exception or coroutine is cancelled, abort the diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index eb5a825d2..635c3f801 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -620,10 +620,10 @@ class InferenceEngine: prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ "input_ids" ] - print(prompts_token_ids) if isinstance(prompts_token_ids, list): - pass + if isinstance(prompts_token_ids[0], torch.Tensor): + prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids] elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): prompts_token_ids = prompts_token_ids.tolist() else: @@ -739,8 +739,6 @@ class InferenceEngine: next_tokens = self.request_handler.search_tokens(self.generation_config, logits) self.request_handler.append_next_tokens(next_tokens) - print("in step", logits) - self.request_handler.search_tokens(self.generation_config, logits) finished_sequences = self.request_handler.update() diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 6837a80c5..12c9cebf7 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -209,6 +209,7 @@ class RequestHandler: break num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num) + # for now the recycle logic is not working remove_list.extend(lst[:num_seqs_to_add]) self.running_list.extend(lst[:num_seqs_to_add]) diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py index c182c5160..1d3a6b497 100644 --- a/colossalai/inference/server/api_server.py +++ b/colossalai/inference/server/api_server.py @@ -58,7 +58,7 @@ async def generate(request: Request) -> Response: # Streaming case def stream_results(): for request_output in results: - ret = {"text": request_output} + ret = {"text": request_output[len(prompt) :]} yield (json.dumps(ret) + "\0").encode("utf-8") if stream: @@ -71,7 +71,7 @@ async def generate(request: Request) -> Response: # Abort the request if the client disconnects. 
engine.abort(request_id) return Response(status_code=499) - final_output = request_output + final_output = request_output[len(prompt) :] assert final_output is not None ret = {"text": final_output} @@ -81,11 +81,15 @@ async def generate(request: Request) -> Response: @app.post("/v1/completion") async def create_completion(request: Request): request_dict = await request.json() + stream = request_dict.pop("stream", False) generation_config = get_generation_config(request_dict) - generator = await completion_serving.create_completion(request, generation_config) - output = tokenizer.decode(generator.output_token_id) - ret = {"request_id": generator.request_id, "text": output} - return ret + result = await completion_serving.create_completion(request, generation_config) + + ret = {"request_id": result.request_id, "text": result.output} + if stream: + return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream") + else: + return JSONResponse(content=ret) def get_generation_config(request): diff --git a/colossalai/inference/server/completion_service.py b/colossalai/inference/server/completion_service.py index bb2160009..61833b031 100644 --- a/colossalai/inference/server/completion_service.py +++ b/colossalai/inference/server/completion_service.py @@ -18,18 +18,17 @@ class CompletionServing: async def create_completion(self, request, generation_config): request_dict = await request.json() request_id = id_generator() + prompt = request_dict.pop("prompt") # it is not a intuitive way self.engine.engine.generation_config = generation_config result_generator = self.engine.generate(request_id, prompt=prompt) - final_res = None - async for res in result_generator: - if await request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(request_id) - return {"error_msg": "Client disconnected"} - final_res = res + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + raise RuntimeError("Client disconnected") + final_res = await result_generator return final_res diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 334a39b4e..216dfd1eb 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -61,6 +61,7 @@ class Sequence: pad_token_id (int): The pad token id for this inference process. max_output_len (int): Maximum output length. ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. + output(str): The output of sequence """ request_id: int @@ -73,6 +74,7 @@ class Sequence: max_output_len: int = 256 # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future. 
ignore_eos: bool = False + output: str = None def __post_init__(self): self.output_token_id = [] diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index e0da816bd..3a1de6d6a 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -598,6 +598,8 @@ def decoding_fused_rotary_embedding( """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) == v.size(0) + assert k.size(1) == v.size(1) + assert k_cache.size(-1) == v_cache.size(-1) if head_dim >= 512: num_warps = 16 diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py new file mode 100644 index 000000000..7402a9c04 --- /dev/null +++ b/examples/inference/client/locustfile.py @@ -0,0 +1,30 @@ +from locust import HttpUser, between, tag, task + + +class QuickstartUser(HttpUser): + wait_time = between(1, 5) + + @tag("online-generation") + @task(5) + def completion(self): + self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"}) + + @tag("online-generation") + @task(5) + def completion_streaming(self): + self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + + @tag("offline-generation") + @task(5) + def generate_stream(self): + self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"}) + + @tag("offline-generation") + @task(5) + def generate(self): + self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"}) + + @tag("online-generation", "offline-generation") + @task + def get_models(self): + self.client.get("/v0/models") diff --git a/examples/inference/client/run_locust.sh b/examples/inference/client/run_locust.sh new file mode 100644 index 000000000..31f4c962e --- /dev/null +++ b/examples/inference/client/run_locust.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +#argument1: model_path + +# launch server +model_path=${1:-"lmsys/vicuna-7b-v1.3"} +echo "Model Path: $model_path" +echo "Starting server..." +python -m colossalai.inference.server.api_server --model $model_path & +SERVER_PID=$! + +# waiting time +sleep 60 + +# Run Locust +echo "Starting Locust..." +echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information." +locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10 + +# kill Server +echo "Stopping server..." 
+kill $SERVER_PID + +echo "Test and server shutdown completely" diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py new file mode 100644 index 000000000..0b0d92c7c --- /dev/null +++ b/tests/test_infer/test_continuous_batching.py @@ -0,0 +1,89 @@ +import random + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def generate_inputs(num_sequences, min_length, max_length): + sequences = [] + for _ in range(num_sequences): + length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item() + # generating randomly lengthed sequences + sequence = torch.randint(10, 30000, size=(length,)) + sequences.append(sequence) + return sequences + + +@parameterize( + "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50 +) +def check_inference_engine(use_engine=False, prompt_template=None): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half() + model = model.eval() + + inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len) + + if use_engine: + inference_config = InferenceConfig( + max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + assert inference_engine.generation_config.max_new_tokens == max_output_len + inference_engine.add_request(prompts_token_ids=inputs_token_ids) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=max_output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + assert len(outputs) == 10 * max_batch_size + + +@parameterize("prompt_template", [None, "llama"]) +def check_continuous_batching(prompt_template): + check_inference_engine(use_engine=True, prompt_template=prompt_template) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_continuous_batching() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_continuous_batching(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_continuous_batching() From 
c06403286567f62cb0a6dfc5e075cf60e291cea9 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Sun, 7 Apr 2024 14:45:43 +0800 Subject: [PATCH 3/7] [Online Server] Chat Api for streaming and not streaming response (#5470) * fix bugs * fix bugs * fix api server * fix api server * add chat api and test * del request.n --- colossalai/inference/server/api_server.py | 54 ++++++-- colossalai/inference/server/chat_service.py | 142 ++++++++++++++++++++ colossalai/inference/server/utils.py | 20 +++ colossalai/inference/struct.py | 13 +- examples/inference/client/locustfile.py | 30 ++++- examples/inference/client/run_locust.sh | 7 +- tests/test_infer/test_server.py | 79 +++++++++++ 7 files changed, 326 insertions(+), 19 deletions(-) create mode 100644 colossalai/inference/server/chat_service.py create mode 100644 tests/test_infer/test_server.py diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py index 1d3a6b497..60ccf15fc 100644 --- a/colossalai/inference/server/api_server.py +++ b/colossalai/inference/server/api_server.py @@ -11,7 +11,6 @@ Doc: -d '{"prompt":"hello, who are you? ","stream":"False"}'` """ - import argparse import json @@ -21,16 +20,20 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.inference.config import InferenceConfig +from colossalai.inference.server.chat_service import ChatServing from colossalai.inference.server.completion_service import CompletionServing from colossalai.inference.server.utils import id_generator from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa TIMEOUT_KEEP_ALIVE = 5 # seconds. -app = FastAPI() -engine = None supported_models_dict = {"Llama_Models": ("llama2-7b",)} prompt_template_choices = ["llama", "vicuna"] +async_engine = None +chat_serving = None +completion_serving = None + +app = FastAPI() @app.get("/v0/models") @@ -49,7 +52,7 @@ async def generate(request: Request) -> Response: """ request_dict = await request.json() prompt = request_dict.pop("prompt") - stream = request_dict.pop("stream", None) + stream = request_dict.pop("stream", "false").lower() request_id = id_generator() generation_config = get_generation_config(request_dict) @@ -61,7 +64,7 @@ async def generate(request: Request) -> Response: ret = {"text": request_output[len(prompt) :]} yield (json.dumps(ret) + "\0").encode("utf-8") - if stream: + if stream == "true": return StreamingResponse(stream_results()) # Non-streaming case @@ -81,17 +84,31 @@ async def generate(request: Request) -> Response: @app.post("/v1/completion") async def create_completion(request: Request): request_dict = await request.json() - stream = request_dict.pop("stream", False) + stream = request_dict.pop("stream", "false").lower() generation_config = get_generation_config(request_dict) result = await completion_serving.create_completion(request, generation_config) ret = {"request_id": result.request_id, "text": result.output} - if stream: + if stream == "true": return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream") else: return JSONResponse(content=ret) +@app.post("/v1/chat") +async def create_chat(request: Request): + request_dict = await request.json() + + stream = request_dict.get("stream", "false").lower() + generation_config = get_generation_config(request_dict) + message = await chat_serving.create_chat(request, generation_config) + if stream == "true": 
+ return StreamingResponse(content=message, media_type="text/event-stream") + else: + ret = {"role": message.role, "text": message.content} + return ret + + def get_generation_config(request): generation_config = async_engine.engine.generation_config for arg in request: @@ -175,6 +192,18 @@ def parse_args(): "specified, the model name will be the same as " "the huggingface name.", ) + parser.add_argument( + "--chat-template", + type=str, + default=None, + help="The file path to the chat template, " "or the template in single-line form " "for the specified model", + ) + parser.add_argument( + "--response-role", + type=str, + default="assistant", + help="The role name to return if " "`request.add_generation_prompt=true`.", + ) parser = add_engine_config(parser) return parser.parse_args() @@ -182,7 +211,6 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - inference_config = InferenceConfig.from_dict(vars(args)) model = AutoModelForCausalLM.from_pretrained(args.model) tokenizer = AutoTokenizer.from_pretrained(args.model) @@ -191,10 +219,16 @@ if __name__ == "__main__": ) engine = async_engine.engine completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__) - + chat_serving = ChatServing( + async_engine, + served_model=model.__class__.__name__, + tokenizer=tokenizer, + response_role=args.response_role, + chat_template=args.chat_template, + ) app.root_path = args.root_path uvicorn.run( - app, + app=app, host=args.host, port=args.port, log_level="debug", diff --git a/colossalai/inference/server/chat_service.py b/colossalai/inference/server/chat_service.py new file mode 100644 index 000000000..d84e82d29 --- /dev/null +++ b/colossalai/inference/server/chat_service.py @@ -0,0 +1,142 @@ +import asyncio +import codecs +import logging + +from fastapi import Request + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + +from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator + +logger = logging.getLogger("colossalai-inference") + + +class ChatServing: + def __init__( + self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None + ): + self.engine = engine + self.served_model = served_model + self.tokenizer = tokenizer + self.response_role = response_role + self._load_chat_template(chat_template) + try: + asyncio.get_running_loop() + except RuntimeError: + pass + + async def create_chat(self, request: Request, generation_config): + request_dict = await request.json() + messages = request_dict["messages"] + stream = request_dict.pop("stream", "false").lower() + add_generation_prompt = request_dict.pop("add_generation_prompt", False) + request_id = id_generator() + try: + prompt = self.tokenizer.apply_chat_template( + conversation=messages, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + except Exception as e: + raise RuntimeError(f"Error in applying chat template from request: {str(e)}") + + # it is not a intuitive way + self.engine.engine.generation_config = generation_config + result_generator = self.engine.generate(request_id, prompt=prompt) + + if stream == "true": + return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id) + else: + return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id) + + async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int): + # Send first response for each 
request.n (index) with the role + role = self.get_chat_request_role(request, request_dict) + n = request_dict.get("n", 1) + echo = request_dict.get("echo", "false").lower() + for i in range(n): + choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role)) + data = choice_data.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if echo == "true": + last_msg_content = "" + if ( + request_dict["messages"] + and isinstance(request_dict["messages"], list) + and request_dict["messages"][-1].get("content") + and request_dict["messages"][-1].get("role") == role + ): + last_msg_content = request_dict["messages"][-1]["content"] + if last_msg_content: + for i in range(n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, message=DeltaMessage(content=last_msg_content) + ) + data = choice_data.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + result = await result_generator + choice_data = DeltaMessage(content=result.output) + data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True) + yield f"data: {data}\n\n" + + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: Request, + request_dict: dict, + result_generator, + request_id, + ): + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + return {"error_msg": "Client disconnected"} + + result = await result_generator + assert result is not None + role = self.get_chat_request_role(request, request_dict) + choice_data = ChatMessage(role=role, content=result.output) + echo = request_dict.get("echo", "false").lower() + + if echo == "true": + last_msg_content = "" + if ( + request.messages + and isinstance(request.messages, list) + and request.messages[-1].get("content") + and request.messages[-1].get("role") == role + ): + last_msg_content = request.messages[-1]["content"] + + full_message = last_msg_content + choice_data.content + choice_data.content = full_message + + return choice_data + + def get_chat_request_role(self, request: Request, request_dict: dict) -> str: + add_generation_prompt = request_dict.get("add_generation_prompt", False) + if add_generation_prompt: + return self.response_role + else: + return request_dict["messages"][-1]["role"] + + def _load_chat_template(self, chat_template): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + self.tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape") + + logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}") + elif self.tokenizer.chat_template is not None: + logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}") + else: + logger.warning("No chat template provided. 
Chat API will not work.") diff --git a/colossalai/inference/server/utils.py b/colossalai/inference/server/utils.py index c10826f73..9eac26576 100644 --- a/colossalai/inference/server/utils.py +++ b/colossalai/inference/server/utils.py @@ -1,3 +1,8 @@ +from typing import Any, Optional + +from pydantic import BaseModel + + # make it singleton class NumericIDGenerator: _instance = None @@ -14,3 +19,18 @@ class NumericIDGenerator: id_generator = NumericIDGenerator() + + +class ChatMessage(BaseModel): + role: str + content: Any + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[Any] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + message: DeltaMessage diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 216dfd1eb..1a3094a27 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -165,12 +165,13 @@ class Sequence: def __repr__(self) -> str: return ( f"(request_id={self.request_id}, " - f"prompt={self.prompt}, " - f"output_token_id={self.output_token_id}," - f"status={self.status.name}, " - f"sample_params={self.sample_params}, " - f"input_len={self.input_len}," - f"output_len={self.output_len})" + f"prompt={self.prompt},\n" + f"output_token_id={self.output_token_id},\n" + f"output={self.output},\n" + f"status={self.status.name},\n" + f"sample_params={self.sample_params},\n" + f"input_len={self.input_len},\n" + f"output_len={self.output_len})\n" ) diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py index 7402a9c04..af00f3c91 100644 --- a/examples/inference/client/locustfile.py +++ b/examples/inference/client/locustfile.py @@ -14,9 +14,37 @@ class QuickstartUser(HttpUser): def completion_streaming(self): self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + @tag("online-chat") + @task(5) + def chat(self): + self.client.post( + "v1/chat", + json={ + "converation": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ], + "stream": "False", + }, + ) + + @tag("online-chat") + @task(5) + def chat_streaming(self): + self.client.post( + "v1/chat", + json={ + "converation": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ], + "stream": "True", + }, + ) + @tag("offline-generation") @task(5) - def generate_stream(self): + def generate_streaming(self): self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"}) @tag("offline-generation") diff --git a/examples/inference/client/run_locust.sh b/examples/inference/client/run_locust.sh index 31f4c962e..fe742fda9 100644 --- a/examples/inference/client/run_locust.sh +++ b/examples/inference/client/run_locust.sh @@ -4,9 +4,10 @@ # launch server model_path=${1:-"lmsys/vicuna-7b-v1.3"} +chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" echo "Model Path: $model_path" echo "Starting server..." -python -m colossalai.inference.server.api_server --model $model_path & +python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template & SERVER_PID=$! # waiting time @@ -15,8 +16,10 @@ sleep 60 # Run Locust echo "Starting Locust..." echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information." 
+echo "Test completion api first" locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10 - +echo "Test chat api" +locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10 # kill Server echo "Stopping server..." kill $SERVER_PID diff --git a/tests/test_infer/test_server.py b/tests/test_infer/test_server.py new file mode 100644 index 000000000..05ac5a264 --- /dev/null +++ b/tests/test_infer/test_server.py @@ -0,0 +1,79 @@ +# inspired by vLLM +import subprocess +import sys +import time + +import pytest +import ray +import requests + +MAX_WAITING_TIME = 300 + +pytestmark = pytest.mark.asyncio + + +@ray.remote(num_gpus=1) +class ServerRunner: + def __init__(self, args): + self.proc = subprocess.Popen( + ["python3", "-m", "colossalai.inference.server.api_server"] + args, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get("http://localhost:8000/v0/models").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_WAITING_TIME: + raise RuntimeError("Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +@pytest.fixture(scope="session") +def server(): + ray.init() + server_runner = ServerRunner.remote( + [ + "--model", + "/home/chenjianghai/data/llama-7b-hf", + ] + ) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +async def test_completion(server): + data = {"prompt": "How are you?", "stream": "False"} + response = await server.post("v1/completion", json=data) + assert response is not None + + +async def test_chat(server): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] + data = {"messages": messages, "stream": "False"} + response = await server.post("v1/chat", data) + assert response is not None + + +if __name__ == "__main__": + pytest.main([__file__]) From 7bbb28e48bdb5849d9dfb118d7bf2959d79bbe02 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 11 Apr 2024 10:12:31 +0800 Subject: [PATCH 4/7] [Inference] resolve rebase conflicts fix --- colossalai/inference/core/engine.py | 2 +- colossalai/shardformer/layer/embedding.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 635c3f801..3f456e1f9 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,6 +1,6 @@ import time from itertools import count -from typing import Dict, List, Optional, Tuple, Union, Iterable +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index cb7eceae4..93df5e522 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -248,7 +248,6 @@ class VocabParallelEmbedding1D(PaddingParallelModule): he initializer of weight, defaults to normal initializer. 
The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. From 61a1b2e798edcbf91ac35966a4047407ad6aa62d Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 8 May 2024 15:14:06 +0800 Subject: [PATCH 5/7] [Inference] Fix bugs and docs for feat/online-server (#5598) * fix test bugs * add do sample test * del useless lines * fix comments * fix tests * delete version tag * delete version tag * add * del test sever * fix test * fix * Revert "add" This reverts commit b9305fb02440d5cd566d32b508bee9f9c13dda15. --- colossalai/inference/config.py | 5 +- colossalai/inference/core/async_engine.py | 52 ++++++++---- colossalai/inference/core/engine.py | 13 ++- colossalai/inference/core/request_handler.py | 2 +- colossalai/inference/server/api_server.py | 40 ++-------- colossalai/shardformer/layer/embedding.py | 2 +- examples/inference/client/locustfile.py | 10 +-- .../test_async_engine/test_async_engine.py | 16 ++-- ...uest_tracker.py => test_request_tracer.py} | 27 +++---- tests/test_infer/test_continuous_batching.py | 18 ++++- tests/test_infer/test_inference_engine.py | 6 +- tests/test_infer/test_server.py | 79 ------------------- 12 files changed, 98 insertions(+), 172 deletions(-) rename tests/test_infer/test_async_engine/{test_request_tracker.py => test_request_tracer.py} (69%) delete mode 100644 tests/test_infer/test_server.py diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 421c6b589..ee1cd7cfb 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -1,9 +1,8 @@ """ Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. """ -import dataclasses import logging -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import Any, Dict, Optional, Union import torch @@ -218,7 +217,7 @@ class InferenceConfig: @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": # Get the list of attributes of this dataclass. - attrs = [attr.name for attr in dataclasses.fields(cls)] + attrs = [attr.name for attr in fields(cls)] inference_config_args = {} for attr in attrs: if attr in config_dict: diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py index e23d0b90f..6f7ab15d8 100644 --- a/colossalai/inference/core/async_engine.py +++ b/colossalai/inference/core/async_engine.py @@ -1,7 +1,7 @@ import asyncio import logging from functools import partial -from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type +from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type from colossalai.inference.core.engine import InferenceEngine @@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve logger = logging.getLogger("colossalai-inference") -def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None: +def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None: msg = "Task finished unexpectedly. This should never happen! 
" try: try: @@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac class RequstStream: - """A stream of Output for a request that can be - iterated over asynchronously.""" + """ + A stream of Output for a request that can be iterated over asynchronously. + Attributes: 1.request_id: The id of the request. + 2._future: A future that will be set when the request is finished. + Methods: set_result and get_result, results will be set when finished, for once, and + the `self.future` will be set to done. + + """ def __init__(self, request_id: int) -> None: self.request_id = request_id @@ -51,6 +57,10 @@ class RequstStream: class Tracer: """ Recording new requests and finished requests. + Attributes: 1._request_streams: We create one stream for each request to trace the output. + 2._finished_requests: A queue to store the finished requests. + 3._new_requests: New requests will be stored in this queue first, before sending them to the engine. + 4.new_requests_event: An event to notify the engine that there are new requests. """ def __init__(self) -> None: @@ -93,8 +103,8 @@ class Tracer: raise KeyError(f"Request {request_id} already exists.") stream = RequstStream(request_id) + logger.info(f"Added request {request_id}.") self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs})) - self.new_requests_event.set() return stream @@ -108,6 +118,7 @@ class Tracer: if request_id not in self._request_streams or self._request_streams[request_id].finished: # The request has already finished or been aborted. + # The requests in new_requests will be aborted when try to get them(if marked aborted) return self._request_streams[request_id].set_result(None) @@ -117,9 +128,18 @@ class Tracer: Get new requests from http server. """ new_requests: List[Dict] = [] + finished_requests: Set[int] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() + if new_request["request_id"] in finished_requests: + # The request has been aborted. + stream.set_result(None) + continue self._request_streams[stream.request_id] = stream new_requests.append(new_request) @@ -133,7 +153,8 @@ class Tracer: class _AsyncInferenceEngine(InferenceEngine): """ - Async methods for Inference Engine. + Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for + Methods: 1. async_step: The async version of Engine.step() """ async def async_step(self) -> List[str]: @@ -161,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine): if self.inference_config.pad_input: logits = logits[:, -1, :] self.request_handler.search_tokens(self.generation_config, logits) - # Return: List[Sequence] + finished_sequences = self.request_handler.update() for sequence in finished_sequences: sequence.output = self.tokenizer.decode(sequence.output_token_id) - return finished_sequences, self.request_handler.current_requests_in_batch() > 0 + return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0 class AsyncInferenceEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for the InferenceEngine class. This class is used to wrap the InferenceEngine class to make it asynchronous. It uses asyncio to create a background loop that keeps processing incoming - requests. 
The LLMEngine is kicked by the generate method when there are - requests in the waiting queue. The generate method yields the outputs - from the InferenceEngine to the caller. + requests. Note that this class does not hold model directly, when incoming a new + request, it first called `add_request` and the Tracer will record the request, putting + it to the background `InferenceEngine`(done in background loop) to process. You can + consider this engine as an interface. """ _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine @@ -253,7 +275,7 @@ class AsyncInferenceEngine: prompt_token_ids: Optional[List[int]] = None, ) -> RequstStream: """ - Add a request to the background tracker(waitting queue), start the background loop if needed. + Add a request to the background tracker(waiting queue), start the background loop if needed. """ if not self.background_loop_status: if self.start_engine_loop: @@ -276,14 +298,12 @@ class AsyncInferenceEngine: """ Generate output from a request. It receives the request from http server, adds it into the waitting queue of Async Engine and streams the output sequence. - """ try: stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids) return await stream.get_result() except (Exception, asyncio.CancelledError) as e: - # If there is an exception or coroutine is cancelled, abort the - # request. + # If there is an exception or coroutine is cancelled, abort the request. self._abort(request_id) raise e diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 3f456e1f9..02a8c92a2 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -527,10 +527,15 @@ class InferenceEngine: List[str]: Inference result returned by one generation. 
""" with torch.inference_mode(): +<<<<<<< HEAD if isinstance(prompts, str) and isinstance(request_ids, int): prompts = [prompts] request_ids = [request_ids] +======= + if prompts is not None or prompts_token_ids is not None: + self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids) +>>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598) if prompts is not None or prompts_token_ids is not None: gen_config_dict = generation_config.to_dict() if generation_config is not None else {} @@ -612,6 +617,9 @@ class InferenceEngine: block_size = self.inference_config.block_size + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + if prompts is not None and not isinstance(prompts, list): prompts = [prompts] @@ -621,9 +629,10 @@ class InferenceEngine: "input_ids" ] + # list of torch Tensor if isinstance(prompts_token_ids, list): if isinstance(prompts_token_ids[0], torch.Tensor): - prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids] + prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): prompts_token_ids = prompts_token_ids.tolist() else: @@ -738,8 +747,6 @@ class InferenceEngine: logits = logits[:, -1, :] next_tokens = self.request_handler.search_tokens(self.generation_config, logits) self.request_handler.append_next_tokens(next_tokens) - - self.request_handler.search_tokens(self.generation_config, logits) finished_sequences = self.request_handler.update() return finished_sequences diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 12c9cebf7..03b4d2305 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -328,7 +328,7 @@ class RequestHandler: def check_unfinished_seqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() - def current_requests_in_batch(self) -> int: + def total_requests_in_batch_bucket(self) -> int: return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size def search_tokens(self, generation_config: GenerationConfig, logits): diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py index 60ccf15fc..dfbd2c906 100644 --- a/colossalai/inference/server/api_server.py +++ b/colossalai/inference/server/api_server.py @@ -6,9 +6,10 @@ Doc: Usage: (for local user) - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model` - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api - - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \ + - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \ -H 'Content-Type: application/json' \ -d '{"prompt":"hello, who are you? 
","stream":"False"}'` + Version: V1.0 """ import argparse @@ -36,7 +37,8 @@ completion_serving = None app = FastAPI() -@app.get("/v0/models") +# NOTE: (CjhHa1) models are still under development, need to be updated +@app.get("/models") def get_available_models() -> Response: return JSONResponse(supported_models_dict) @@ -81,7 +83,7 @@ async def generate(request: Request) -> Response: return JSONResponse(ret) -@app.post("/v1/completion") +@app.post("/completion") async def create_completion(request: Request): request_dict = await request.json() stream = request_dict.pop("stream", "false").lower() @@ -95,7 +97,7 @@ async def create_completion(request: Request): return JSONResponse(content=ret) -@app.post("/v1/chat") +@app.post("/chat") async def create_chat(request: Request): request_dict = await request.json() @@ -127,14 +129,6 @@ def add_engine_config(parser): help="model context length. If unspecified, " "will be automatically derived from the model.", ) # Parallel arguments - parser.add_argument( - "--worker-use-ray", - action="store_true", - help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU", - ) - - parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages") - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas") # KV cache arguments @@ -149,28 +143,6 @@ def add_engine_config(parser): default=None, help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.", ) - - # Quantization settings. - parser.add_argument( - "--quantization", - "-q", - type=str, - choices=["awq", "gptq", "squeezellm", None], - default=None, - help="Method used to quantize the weights. If " - "None, we first check the `quantization_config` " - "attribute in the model config file. If that is " - "None, we assume the model weights are not " - "quantized and use `dtype` to determine the data " - "type of the weights.", - ) - parser.add_argument( - "--enforce-eager", - action="store_true", - help="Always use eager-mode PyTorch. If False, " - "will use eager mode and CUDA graph in hybrid " - "for maximal performance and flexibility.", - ) return parser diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 93df5e522..9b77774aa 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -248,7 +248,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule): he initializer of weight, defaults to normal initializer. The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - + :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py index af00f3c91..a65c8b667 100644 --- a/examples/inference/client/locustfile.py +++ b/examples/inference/client/locustfile.py @@ -7,18 +7,18 @@ class QuickstartUser(HttpUser): @tag("online-generation") @task(5) def completion(self): - self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"}) + self.client.post("/completion", json={"prompt": "hello, who are you? 
", "stream": "False"}) @tag("online-generation") @task(5) def completion_streaming(self): - self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) @tag("online-chat") @task(5) def chat(self): self.client.post( - "v1/chat", + "/chat", json={ "converation": [ {"role": "system", "content": "you are a helpful assistant"}, @@ -32,7 +32,7 @@ class QuickstartUser(HttpUser): @task(5) def chat_streaming(self): self.client.post( - "v1/chat", + "/chat", json={ "converation": [ {"role": "system", "content": "you are a helpful assistant"}, @@ -55,4 +55,4 @@ class QuickstartUser(HttpUser): @tag("online-generation", "offline-generation") @task def get_models(self): - self.client.get("/v0/models") + self.client.get("/models") diff --git a/tests/test_infer/test_async_engine/test_async_engine.py b/tests/test_infer/test_async_engine/test_async_engine.py index ebca11c72..ac532b1b1 100644 --- a/tests/test_infer/test_async_engine/test_async_engine.py +++ b/tests/test_infer/test_async_engine/test_async_engine.py @@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine @dataclass -class SequenceTpye: +class MockSequence: request_id: int @@ -20,7 +20,11 @@ class MockEngine: async def async_step(self): self.step_calls += 1 - return [SequenceTpye(request_id=self.request_id)] if self.request_id else [] + return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False) + + def add_single_request(self, **kwargs): + del kwargs + self.add_request_calls += 1 def generate(self, request_id): self.request_id = request_id @@ -37,14 +41,14 @@ class MockEngine: self.abort_request_calls += 1 -class MockAsyncLLMEngine(AsyncInferenceEngine): +class MockAsyncInferenceEngine(AsyncInferenceEngine): def _init_engine(self, *args, **kwargs): return MockEngine() @pytest.mark.asyncio async def test_new_requests_event(): - engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False) + engine = MockAsyncInferenceEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 @@ -74,7 +78,3 @@ async def test_new_requests_event(): await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == 5 - - -if __name__ == "__main__": - test_new_requests_event() diff --git a/tests/test_infer/test_async_engine/test_request_tracker.py b/tests/test_infer/test_async_engine/test_request_tracer.py similarity index 69% rename from tests/test_infer/test_async_engine/test_request_tracker.py rename to tests/test_infer/test_async_engine/test_request_tracer.py index 9a797a862..14bcb9628 100644 --- a/tests/test_infer/test_async_engine/test_request_tracker.py +++ b/tests/test_infer/test_async_engine/test_request_tracer.py @@ -1,6 +1,6 @@ import pytest -from colossalai.inference.core.async_engine import RequestTracker +from colossalai.inference.core.async_engine import Tracer from colossalai.inference.struct import Sequence @@ -15,27 +15,25 @@ class SampleEvent: self.flag = False -def test_request_tracker(): - tracker = RequestTracker() +def test_request_tracer(): + tracker = Tracer() tracker.new_requests_event = SampleEvent() stream_1 = tracker.add_request(1) assert tracker.new_requests_event.flag - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag assert len(new) == 1 assert new[0]["request_id"] == 1 - 
assert not finished assert not stream_1.finished stream_2 = tracker.add_request(2) stream_3 = tracker.add_request(3) assert tracker.new_requests_event.flag - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag assert len(new) == 2 assert new[0]["request_id"] == 2 assert new[1]["request_id"] == 3 - assert not finished assert not stream_2.finished assert not stream_3.finished @@ -45,28 +43,21 @@ def test_request_tracker(): assert not tracker.new_requests_event.flag tracker.abort_request(1) - new, finished = tracker.get_new_and_finished_requests() - assert len(finished) == 1 - assert 1 in finished + new = tracker.get_new_requests() assert not new - assert stream_1.finished stream_4 = tracker.add_request(4) tracker.abort_request(4) assert tracker.new_requests_event.flag - new, finished = tracker.get_new_and_finished_requests() - assert len(finished) == 1 - assert 4 in finished + new = tracker.get_new_requests() assert not new assert stream_4.finished stream_5 = tracker.add_request(5) assert tracker.new_requests_event.flag tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag - assert len(finished) == 1 - assert 2 in finished assert len(new) == 1 assert new[0]["request_id"] == 5 assert stream_2.finished @@ -74,4 +65,4 @@ def test_request_tracker(): if __name__ == "__main__": - test_request_tracker() + test_request_tracer() diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py index 0b0d92c7c..350ed473e 100644 --- a/tests/test_infer/test_continuous_batching.py +++ b/tests/test_infer/test_continuous_batching.py @@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length): @parameterize( - "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50 + "test_config", + [ + { + "max_batch_size": 8, + "max_output_len": 512, + "max_input_len": 64, + "do_sample": False, + } + ], ) -def check_inference_engine(use_engine=False, prompt_template=None): +def check_inference_engine(test_config, use_engine=False, prompt_template=None): setup_seed(20) + max_batch_size = test_config["max_batch_size"] + max_input_len = test_config["max_input_len"] + max_output_len = test_config["max_output_len"] + do_sample = test_config["do_sample"] + top_p = 0.5 + top_k = 50 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half() model = model.eval() diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index be1330898..919a10077 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru ) ).cuda() model = model.eval() - inputs = [ "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", "介绍一下武汉,", @@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + generation_config = GenerationConfig( + 
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k + ) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template: @@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru inputs = inputs.cuda() generation_config = GenerationConfig( do_sample=do_sample, + dtype="fp32", top_p=top_p, top_k=top_k, pad_token_id=tokenizer.pad_token_id, diff --git a/tests/test_infer/test_server.py b/tests/test_infer/test_server.py deleted file mode 100644 index 05ac5a264..000000000 --- a/tests/test_infer/test_server.py +++ /dev/null @@ -1,79 +0,0 @@ -# inspired by vLLM -import subprocess -import sys -import time - -import pytest -import ray -import requests - -MAX_WAITING_TIME = 300 - -pytestmark = pytest.mark.asyncio - - -@ray.remote(num_gpus=1) -class ServerRunner: - def __init__(self, args): - self.proc = subprocess.Popen( - ["python3", "-m", "colossalai.inference.server.api_server"] + args, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get("http://localhost:8000/v0/models").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_WAITING_TIME: - raise RuntimeError("Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - -@pytest.fixture(scope="session") -def server(): - ray.init() - server_runner = ServerRunner.remote( - [ - "--model", - "/home/chenjianghai/data/llama-7b-hf", - ] - ) - ray.get(server_runner.ready.remote()) - yield server_runner - ray.shutdown() - - -async def test_completion(server): - data = {"prompt": "How are you?", "stream": "False"} - response = await server.post("v1/completion", json=data) - assert response is not None - - -async def test_chat(server): - messages = [ - {"role": "system", "content": "you are a helpful assistant"}, - {"role": "user", "content": "what is 1+1?"}, - ] - data = {"messages": messages, "stream": "False"} - response = await server.post("v1/chat", data) - assert response is not None - - -if __name__ == "__main__": - pytest.main([__file__]) From bc9063adf1598c3be32fc2d12577d76b9daa79bf Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 8 May 2024 10:36:42 +0000 Subject: [PATCH 6/7] resolve rebase conflicts on Branch feat/online-serving --- colossalai/inference/core/engine.py | 13 +++------ colossalai/inference/server/README.md | 27 +++++++++++++++++++ .../kernel/triton/no_pad_rotary_embedding.py | 2 -- tests/test_infer/test_continuous_batching.py | 2 +- 4 files changed, 31 insertions(+), 13 deletions(-) create mode 100644 colossalai/inference/server/README.md diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 02a8c92a2..1ced54dd7 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -527,16 +527,9 @@ class InferenceEngine: List[str]: Inference result returned by one generation. 
""" with torch.inference_mode(): -<<<<<<< HEAD - if isinstance(prompts, str) and isinstance(request_ids, int): - prompts = [prompts] - request_ids = [request_ids] -======= - if prompts is not None or prompts_token_ids is not None: - self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids) ->>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598) - + prompts = [prompts] + request_ids = [request_ids] if prompts is not None or prompts_token_ids is not None: gen_config_dict = generation_config.to_dict() if generation_config is not None else {} self.add_request( @@ -545,7 +538,7 @@ class InferenceEngine: prompts_token_ids=prompts_token_ids, **gen_config_dict, ) - + output_seqs_list = [] total_tokens_list = [] diff --git a/colossalai/inference/server/README.md b/colossalai/inference/server/README.md new file mode 100644 index 000000000..8b5f29fc0 --- /dev/null +++ b/colossalai/inference/server/README.md @@ -0,0 +1,27 @@ +# Online Service +Colossal-Inference supports fast-api based online service. Simple completion and chat are both supported. Follow the commands below and +you can simply construct a server with both completion and chat functionalities. For now we only support `Llama` model, we will fullfill +the blank quickly. + +# Usage +```bash +# First, Lauch an API locally. +python3 -m colossalai.inference.server.api_server --model path of your llama2 model --chat_template "{% for message in messages %} +{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" + + +# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api + +# For completion service, you can invoke it +curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}' + +# For chat service, you can invoke it +curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"converation": + [{"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"},], + "stream": "False",}' +# If you just want to test a simple generation, turn to generate api +curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}' + +``` +We also support streaming output, simply change the `stream` to `True` in the request body. 
diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 3a1de6d6a..e0da816bd 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -598,8 +598,6 @@ def decoding_fused_rotary_embedding( """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) == v.size(0) - assert k.size(1) == v.size(1) - assert k_cache.size(-1) == v_cache.size(-1) if head_dim >= 512: num_warps = 16 diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py index 350ed473e..a88798619 100644 --- a/tests/test_infer/test_continuous_batching.py +++ b/tests/test_infer/test_continuous_batching.py @@ -89,7 +89,7 @@ def check_continuous_batching(prompt_template): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") check_continuous_batching() From 5d9a49483d98ccd4bebebbfd039162caceefe6bd Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 9 May 2024 05:44:05 +0000 Subject: [PATCH 7/7] [Inference] Add example test_ci script --- examples/inference/client/test_ci.sh | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 examples/inference/client/test_ci.sh diff --git a/examples/inference/client/test_ci.sh b/examples/inference/client/test_ci.sh new file mode 100644 index 000000000..b130fc486 --- /dev/null +++ b/examples/inference/client/test_ci.sh @@ -0,0 +1,4 @@ +#!/bin/bash +echo "Skip the test (this test is slow)" + +# bash ./run_benchmark.sh
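Finally, for readers following the async plumbing introduced in this series: the tracer tests earlier drive `Tracer` synchronously with a stub event, while in the server the API layer and the background loop cooperate through it. The sketch below walks that handshake by hand (a rough illustration only — it reuses the `Sequence` constructor arguments from those tests and drives the tracer directly instead of running a real engine):

```python
import asyncio

from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence


async def main():
    tracer = Tracer()
    # The tests assign a stub event; here we attach a real asyncio.Event instead.
    tracer.new_requests_event = asyncio.Event()

    # API layer: register a request and keep the returned stream.
    stream = tracer.add_request(1, prompt="hello, who are you? ")

    # Background loop: drain pending requests before stepping the engine.
    new_requests = tracer.get_new_requests()
    assert new_requests[0]["request_id"] == 1

    # Background loop: report the finished sequence, which resolves the stream
    # the API layer is awaiting (same positional arguments as in the tracer tests).
    tracer.process_finished_request(Sequence(1, "hello, who are you? ", [], 4, [], 0, 0))
    result = await stream.get_result()
    print(result)


asyncio.run(main())
```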