From c06403286567f62cb0a6dfc5e075cf60e291cea9 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Sun, 7 Apr 2024 14:45:43 +0800 Subject: [PATCH] [Online Server] Chat Api for streaming and not streaming response (#5470) * fix bugs * fix bugs * fix api server * fix api server * add chat api and test * del request.n --- colossalai/inference/server/api_server.py | 54 ++++++-- colossalai/inference/server/chat_service.py | 142 ++++++++++++++++++++ colossalai/inference/server/utils.py | 20 +++ colossalai/inference/struct.py | 13 +- examples/inference/client/locustfile.py | 30 ++++- examples/inference/client/run_locust.sh | 7 +- tests/test_infer/test_server.py | 79 +++++++++++ 7 files changed, 326 insertions(+), 19 deletions(-) create mode 100644 colossalai/inference/server/chat_service.py create mode 100644 tests/test_infer/test_server.py diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py index 1d3a6b497..60ccf15fc 100644 --- a/colossalai/inference/server/api_server.py +++ b/colossalai/inference/server/api_server.py @@ -11,7 +11,6 @@ Doc: -d '{"prompt":"hello, who are you? ","stream":"False"}'` """ - import argparse import json @@ -21,16 +20,20 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.inference.config import InferenceConfig +from colossalai.inference.server.chat_service import ChatServing from colossalai.inference.server.completion_service import CompletionServing from colossalai.inference.server.utils import id_generator from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa TIMEOUT_KEEP_ALIVE = 5 # seconds. -app = FastAPI() -engine = None supported_models_dict = {"Llama_Models": ("llama2-7b",)} prompt_template_choices = ["llama", "vicuna"] +async_engine = None +chat_serving = None +completion_serving = None + +app = FastAPI() @app.get("/v0/models") @@ -49,7 +52,7 @@ async def generate(request: Request) -> Response: """ request_dict = await request.json() prompt = request_dict.pop("prompt") - stream = request_dict.pop("stream", None) + stream = request_dict.pop("stream", "false").lower() request_id = id_generator() generation_config = get_generation_config(request_dict) @@ -61,7 +64,7 @@ async def generate(request: Request) -> Response: ret = {"text": request_output[len(prompt) :]} yield (json.dumps(ret) + "\0").encode("utf-8") - if stream: + if stream == "true": return StreamingResponse(stream_results()) # Non-streaming case @@ -81,17 +84,31 @@ async def generate(request: Request) -> Response: @app.post("/v1/completion") async def create_completion(request: Request): request_dict = await request.json() - stream = request_dict.pop("stream", False) + stream = request_dict.pop("stream", "false").lower() generation_config = get_generation_config(request_dict) result = await completion_serving.create_completion(request, generation_config) ret = {"request_id": result.request_id, "text": result.output} - if stream: + if stream == "true": return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream") else: return JSONResponse(content=ret) +@app.post("/v1/chat") +async def create_chat(request: Request): + request_dict = await request.json() + + stream = request_dict.get("stream", "false").lower() + generation_config = get_generation_config(request_dict) + message = await chat_serving.create_chat(request, generation_config) + if stream == "true": + return StreamingResponse(content=message, media_type="text/event-stream") + else: + ret = {"role": message.role, "text": message.content} + return ret + + def get_generation_config(request): generation_config = async_engine.engine.generation_config for arg in request: @@ -175,6 +192,18 @@ def parse_args(): "specified, the model name will be the same as " "the huggingface name.", ) + parser.add_argument( + "--chat-template", + type=str, + default=None, + help="The file path to the chat template, " "or the template in single-line form " "for the specified model", + ) + parser.add_argument( + "--response-role", + type=str, + default="assistant", + help="The role name to return if " "`request.add_generation_prompt=true`.", + ) parser = add_engine_config(parser) return parser.parse_args() @@ -182,7 +211,6 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - inference_config = InferenceConfig.from_dict(vars(args)) model = AutoModelForCausalLM.from_pretrained(args.model) tokenizer = AutoTokenizer.from_pretrained(args.model) @@ -191,10 +219,16 @@ if __name__ == "__main__": ) engine = async_engine.engine completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__) - + chat_serving = ChatServing( + async_engine, + served_model=model.__class__.__name__, + tokenizer=tokenizer, + response_role=args.response_role, + chat_template=args.chat_template, + ) app.root_path = args.root_path uvicorn.run( - app, + app=app, host=args.host, port=args.port, log_level="debug", diff --git a/colossalai/inference/server/chat_service.py b/colossalai/inference/server/chat_service.py new file mode 100644 index 000000000..d84e82d29 --- /dev/null +++ b/colossalai/inference/server/chat_service.py @@ -0,0 +1,142 @@ +import asyncio +import codecs +import logging + +from fastapi import Request + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + +from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator + +logger = logging.getLogger("colossalai-inference") + + +class ChatServing: + def __init__( + self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None + ): + self.engine = engine + self.served_model = served_model + self.tokenizer = tokenizer + self.response_role = response_role + self._load_chat_template(chat_template) + try: + asyncio.get_running_loop() + except RuntimeError: + pass + + async def create_chat(self, request: Request, generation_config): + request_dict = await request.json() + messages = request_dict["messages"] + stream = request_dict.pop("stream", "false").lower() + add_generation_prompt = request_dict.pop("add_generation_prompt", False) + request_id = id_generator() + try: + prompt = self.tokenizer.apply_chat_template( + conversation=messages, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + except Exception as e: + raise RuntimeError(f"Error in applying chat template from request: {str(e)}") + + # it is not a intuitive way + self.engine.engine.generation_config = generation_config + result_generator = self.engine.generate(request_id, prompt=prompt) + + if stream == "true": + return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id) + else: + return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id) + + async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int): + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request, request_dict) + n = request_dict.get("n", 1) + echo = request_dict.get("echo", "false").lower() + for i in range(n): + choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role)) + data = choice_data.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if echo == "true": + last_msg_content = "" + if ( + request_dict["messages"] + and isinstance(request_dict["messages"], list) + and request_dict["messages"][-1].get("content") + and request_dict["messages"][-1].get("role") == role + ): + last_msg_content = request_dict["messages"][-1]["content"] + if last_msg_content: + for i in range(n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, message=DeltaMessage(content=last_msg_content) + ) + data = choice_data.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + result = await result_generator + choice_data = DeltaMessage(content=result.output) + data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True) + yield f"data: {data}\n\n" + + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: Request, + request_dict: dict, + result_generator, + request_id, + ): + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + return {"error_msg": "Client disconnected"} + + result = await result_generator + assert result is not None + role = self.get_chat_request_role(request, request_dict) + choice_data = ChatMessage(role=role, content=result.output) + echo = request_dict.get("echo", "false").lower() + + if echo == "true": + last_msg_content = "" + if ( + request.messages + and isinstance(request.messages, list) + and request.messages[-1].get("content") + and request.messages[-1].get("role") == role + ): + last_msg_content = request.messages[-1]["content"] + + full_message = last_msg_content + choice_data.content + choice_data.content = full_message + + return choice_data + + def get_chat_request_role(self, request: Request, request_dict: dict) -> str: + add_generation_prompt = request_dict.get("add_generation_prompt", False) + if add_generation_prompt: + return self.response_role + else: + return request_dict["messages"][-1]["role"] + + def _load_chat_template(self, chat_template): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + self.tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape") + + logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}") + elif self.tokenizer.chat_template is not None: + logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}") + else: + logger.warning("No chat template provided. Chat API will not work.") diff --git a/colossalai/inference/server/utils.py b/colossalai/inference/server/utils.py index c10826f73..9eac26576 100644 --- a/colossalai/inference/server/utils.py +++ b/colossalai/inference/server/utils.py @@ -1,3 +1,8 @@ +from typing import Any, Optional + +from pydantic import BaseModel + + # make it singleton class NumericIDGenerator: _instance = None @@ -14,3 +19,18 @@ class NumericIDGenerator: id_generator = NumericIDGenerator() + + +class ChatMessage(BaseModel): + role: str + content: Any + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[Any] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + message: DeltaMessage diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 216dfd1eb..1a3094a27 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -165,12 +165,13 @@ class Sequence: def __repr__(self) -> str: return ( f"(request_id={self.request_id}, " - f"prompt={self.prompt}, " - f"output_token_id={self.output_token_id}," - f"status={self.status.name}, " - f"sample_params={self.sample_params}, " - f"input_len={self.input_len}," - f"output_len={self.output_len})" + f"prompt={self.prompt},\n" + f"output_token_id={self.output_token_id},\n" + f"output={self.output},\n" + f"status={self.status.name},\n" + f"sample_params={self.sample_params},\n" + f"input_len={self.input_len},\n" + f"output_len={self.output_len})\n" ) diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py index 7402a9c04..af00f3c91 100644 --- a/examples/inference/client/locustfile.py +++ b/examples/inference/client/locustfile.py @@ -14,9 +14,37 @@ class QuickstartUser(HttpUser): def completion_streaming(self): self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + @tag("online-chat") + @task(5) + def chat(self): + self.client.post( + "v1/chat", + json={ + "converation": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ], + "stream": "False", + }, + ) + + @tag("online-chat") + @task(5) + def chat_streaming(self): + self.client.post( + "v1/chat", + json={ + "converation": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ], + "stream": "True", + }, + ) + @tag("offline-generation") @task(5) - def generate_stream(self): + def generate_streaming(self): self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"}) @tag("offline-generation") diff --git a/examples/inference/client/run_locust.sh b/examples/inference/client/run_locust.sh index 31f4c962e..fe742fda9 100644 --- a/examples/inference/client/run_locust.sh +++ b/examples/inference/client/run_locust.sh @@ -4,9 +4,10 @@ # launch server model_path=${1:-"lmsys/vicuna-7b-v1.3"} +chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" echo "Model Path: $model_path" echo "Starting server..." -python -m colossalai.inference.server.api_server --model $model_path & +python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template & SERVER_PID=$! # waiting time @@ -15,8 +16,10 @@ sleep 60 # Run Locust echo "Starting Locust..." echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information." +echo "Test completion api first" locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10 - +echo "Test chat api" +locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10 # kill Server echo "Stopping server..." kill $SERVER_PID diff --git a/tests/test_infer/test_server.py b/tests/test_infer/test_server.py new file mode 100644 index 000000000..05ac5a264 --- /dev/null +++ b/tests/test_infer/test_server.py @@ -0,0 +1,79 @@ +# inspired by vLLM +import subprocess +import sys +import time + +import pytest +import ray +import requests + +MAX_WAITING_TIME = 300 + +pytestmark = pytest.mark.asyncio + + +@ray.remote(num_gpus=1) +class ServerRunner: + def __init__(self, args): + self.proc = subprocess.Popen( + ["python3", "-m", "colossalai.inference.server.api_server"] + args, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get("http://localhost:8000/v0/models").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_WAITING_TIME: + raise RuntimeError("Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +@pytest.fixture(scope="session") +def server(): + ray.init() + server_runner = ServerRunner.remote( + [ + "--model", + "/home/chenjianghai/data/llama-7b-hf", + ] + ) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +async def test_completion(server): + data = {"prompt": "How are you?", "stream": "False"} + response = await server.post("v1/completion", json=data) + assert response is not None + + +async def test_chat(server): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] + data = {"messages": messages, "stream": "False"} + response = await server.post("v1/chat", data) + assert response is not None + + +if __name__ == "__main__": + pytest.main([__file__])