
[Online Server] Chat API for streaming and non-streaming responses (#5470)

* fix bugs

* fix bugs

* fix api server

* fix api server

* add chat api and test

* del request.n
feat/online-serving
Jianghai authored 8 months ago; committed by CjhHa1
Parent commit: c064032865
  1. colossalai/inference/server/api_server.py (54 lines changed)
  2. colossalai/inference/server/chat_service.py (142 lines changed)
  3. colossalai/inference/server/utils.py (20 lines changed)
  4. colossalai/inference/struct.py (13 lines changed)
  5. examples/inference/client/locustfile.py (30 lines changed)
  6. examples/inference/client/run_locust.sh (7 lines changed)
  7. tests/test_infer/test_server.py (79 lines changed)

colossalai/inference/server/api_server.py (54 lines changed)

@@ -11,7 +11,6 @@ Doc:
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
"""
import argparse
import json
@@ -21,16 +20,20 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.server.chat_service import ChatServing
from colossalai.inference.server.completion_service import CompletionServing
from colossalai.inference.server.utils import id_generator
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
prompt_template_choices = ["llama", "vicuna"]
async_engine = None
chat_serving = None
completion_serving = None
app = FastAPI()
@app.get("/v0/models")
@@ -49,7 +52,7 @@ async def generate(request: Request) -> Response:
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", None)
stream = request_dict.pop("stream", "false").lower()
request_id = id_generator()
generation_config = get_generation_config(request_dict)
@@ -61,7 +64,7 @@ async def generate(request: Request) -> Response:
ret = {"text": request_output[len(prompt) :]}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream:
if stream == "true":
return StreamingResponse(stream_results())
# Non-streaming case
@@ -81,17 +84,31 @@ async def generate(request: Request) -> Response:
@app.post("/v1/completion")
async def create_completion(request: Request):
request_dict = await request.json()
stream = request_dict.pop("stream", False)
stream = request_dict.pop("stream", "false").lower()
generation_config = get_generation_config(request_dict)
result = await completion_serving.create_completion(request, generation_config)
ret = {"request_id": result.request_id, "text": result.output}
if stream:
if stream == "true":
return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
else:
return JSONResponse(content=ret)
@app.post("/v1/chat")
async def create_chat(request: Request):
request_dict = await request.json()
stream = request_dict.get("stream", "false").lower()
generation_config = get_generation_config(request_dict)
message = await chat_serving.create_chat(request, generation_config)
if stream == "true":
return StreamingResponse(content=message, media_type="text/event-stream")
else:
ret = {"role": message.role, "text": message.content}
return ret
def get_generation_config(request):
generation_config = async_engine.engine.generation_config
for arg in request:
@@ -175,6 +192,18 @@ def parse_args():
"specified, the model name will be the same as "
"the huggingface name.",
)
parser.add_argument(
"--chat-template",
type=str,
default=None,
help="The file path to the chat template, " "or the template in single-line form " "for the specified model",
)
parser.add_argument(
"--response-role",
type=str,
default="assistant",
help="The role name to return if " "`request.add_generation_prompt=true`.",
)
parser = add_engine_config(parser)
return parser.parse_args()
@@ -182,7 +211,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
inference_config = InferenceConfig.from_dict(vars(args))
model = AutoModelForCausalLM.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
@@ -191,10 +219,16 @@ if __name__ == "__main__":
)
engine = async_engine.engine
completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
chat_serving = ChatServing(
async_engine,
served_model=model.__class__.__name__,
tokenizer=tokenizer,
response_role=args.response_role,
chat_template=args.chat_template,
)
app.root_path = args.root_path
uvicorn.run(
app,
app=app,
host=args.host,
port=args.port,
log_level="debug",
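
A minimal client-side sketch of the new /v1/chat endpoint added above (illustration only; it assumes the server is reachable at http://127.0.0.1:8000, the host used by the Locust example below, and that the non-streaming response has the {"role", "text"} shape returned by create_chat):

import requests

BASE_URL = "http://127.0.0.1:8000"  # assumed default host/port

payload = {
    "messages": [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ],
    "stream": "False",
}

# Non-streaming: the endpoint returns {"role": ..., "text": ...}
resp = requests.post(f"{BASE_URL}/v1/chat", json=payload)
print(resp.json())

# Streaming: the endpoint forwards the chunks produced by chat_completion_stream_generator
payload["stream"] = "True"
with requests.post(f"{BASE_URL}/v1/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)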

colossalai/inference/server/chat_service.py (142 lines changed)

@@ -0,0 +1,142 @@
import asyncio
import codecs
import logging
from fastapi import Request
from colossalai.inference.core.async_engine import AsyncInferenceEngine
from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator
logger = logging.getLogger("colossalai-inference")
class ChatServing:
def __init__(
self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None
):
self.engine = engine
self.served_model = served_model
self.tokenizer = tokenizer
self.response_role = response_role
self._load_chat_template(chat_template)
try:
asyncio.get_running_loop()
except RuntimeError:
pass
async def create_chat(self, request: Request, generation_config):
request_dict = await request.json()
messages = request_dict["messages"]
stream = request_dict.pop("stream", "false").lower()
add_generation_prompt = request_dict.pop("add_generation_prompt", False)
request_id = id_generator()
try:
prompt = self.tokenizer.apply_chat_template(
conversation=messages,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
except Exception as e:
raise RuntimeError(f"Error in applying chat template from request: {str(e)}")
# not an intuitive way: the engine's generation_config is overridden in place
self.engine.engine.generation_config = generation_config
result_generator = self.engine.generate(request_id, prompt=prompt)
if stream == "true":
return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id)
else:
return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id)
async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int):
# Send first response for each request.n (index) with the role
role = self.get_chat_request_role(request, request_dict)
n = request_dict.get("n", 1)
echo = request_dict.get("echo", "false").lower()
for i in range(n):
choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role))
data = choice_data.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if echo == "true":
last_msg_content = ""
if (
request_dict["messages"]
and isinstance(request_dict["messages"], list)
and request_dict["messages"][-1].get("content")
and request_dict["messages"][-1].get("role") == role
):
last_msg_content = request_dict["messages"][-1]["content"]
if last_msg_content:
for i in range(n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, message=DeltaMessage(content=last_msg_content)
)
data = choice_data.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
result = await result_generator
choice_data = DeltaMessage(content=result.output)
data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True)
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self,
request: Request,
request_dict: dict,
result_generator,
request_id,
):
if await request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return {"error_msg": "Client disconnected"}
result = await result_generator
assert result is not None
role = self.get_chat_request_role(request, request_dict)
choice_data = ChatMessage(role=role, content=result.output)
echo = request_dict.get("echo", "false").lower()
if echo == "true":
last_msg_content = ""
if (
request_dict["messages"]
and isinstance(request_dict["messages"], list)
and request_dict["messages"][-1].get("content")
and request_dict["messages"][-1].get("role") == role
):
last_msg_content = request_dict["messages"][-1]["content"]
full_message = last_msg_content + choice_data.content
choice_data.content = full_message
return choice_data
def get_chat_request_role(self, request: Request, request_dict: dict) -> str:
add_generation_prompt = request_dict.get("add_generation_prompt", False)
if add_generation_prompt:
return self.response_role
else:
return request_dict["messages"][-1]["role"]
def _load_chat_template(self, chat_template):
if chat_template is not None:
try:
with open(chat_template, "r") as f:
self.tokenizer.chat_template = f.read()
except OSError:
# If opening the file fails, fall back to treating the argument itself as the
# template, decoding it so that escape sequences are interpreted correctly
self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}")
elif self.tokenizer.chat_template is not None:
logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}")
else:
logger.warning("No chat template provided. Chat API will not work.")
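
For context on the prompt construction above: apply_chat_template renders the messages list into a single prompt string using the tokenizer's Jinja chat template. A minimal sketch, using the vicuna model and the chat template from run_locust.sh below (output shown in comments is approximate):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3")
# same template that run_locust.sh passes via --chat-template
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
)

messages = [
    {"role": "system", "content": "you are a helpful assistant"},
    {"role": "user", "content": "what is 1+1?"},
]
prompt = tokenizer.apply_chat_template(
    conversation=messages, tokenize=False, add_generation_prompt=False
)
print(prompt)
# <|im_start|>system
# you are a helpful assistant<|im_end|>
# <|im_start|>user
# what is 1+1?<|im_end|>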

colossalai/inference/server/utils.py (20 lines changed)

@@ -1,3 +1,8 @@
from typing import Any, Optional
from pydantic import BaseModel
# make it a singleton
class NumericIDGenerator:
_instance = None
@@ -14,3 +19,18 @@ class NumericIDGenerator:
id_generator = NumericIDGenerator()
class ChatMessage(BaseModel):
role: str
content: Any
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[Any] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
message: DeltaMessage
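
A quick illustration of how these new pydantic models are serialized in the streaming chat path (a self-contained sketch mirroring the definitions above; model_dump_json requires pydantic v2):

from typing import Any, Optional

from pydantic import BaseModel


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[Any] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    message: DeltaMessage


# The first streamed chunk carries only the role; later chunks carry only content.
first = ChatCompletionResponseStreamChoice(index=0, message=DeltaMessage(role="assistant"))
print("data: " + first.model_dump_json(exclude_unset=True))
# data: {"index":0,"message":{"role":"assistant"}}

delta = DeltaMessage(content="2")
print("data: " + delta.model_dump_json(exclude_unset=True, exclude_none=True))
# data: {"content":"2"}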

colossalai/inference/struct.py (13 lines changed)

@@ -165,12 +165,13 @@ class Sequence:
def __repr__(self) -> str:
return (
f"(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"output_token_id={self.output_token_id},"
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
f"input_len={self.input_len},"
f"output_len={self.output_len})"
f"prompt={self.prompt},\n"
f"output_token_id={self.output_token_id},\n"
f"output={self.output},\n"
f"status={self.status.name},\n"
f"sample_params={self.sample_params},\n"
f"input_len={self.input_len},\n"
f"output_len={self.output_len})\n"
)

examples/inference/client/locustfile.py (30 lines changed)

@@ -14,9 +14,37 @@ class QuickstartUser(HttpUser):
def completion_streaming(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
@tag("online-chat")
@task(5)
def chat(self):
self.client.post(
"v1/chat",
json={
"converation": [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"},
],
"stream": "False",
},
)
@tag("online-chat")
@task(5)
def chat_streaming(self):
self.client.post(
"v1/chat",
json={
"converation": [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"},
],
"stream": "True",
},
)
@tag("offline-generation")
@task(5)
def generate_stream(self):
def generate_streaming(self):
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
@tag("offline-generation")

examples/inference/client/run_locust.sh (7 lines changed)

@@ -4,9 +4,10 @@
# launch server
model_path=${1:-"lmsys/vicuna-7b-v1.3"}
chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
echo "Model Path: $model_path"
echo "Starting server..."
python -m colossalai.inference.server.api_server --model $model_path &
python -m colossalai.inference.server.api_server --model $model_path --chat-template "$chat_template" &
SERVER_PID=$!
# waiting time
@@ -15,8 +16,10 @@ sleep 60
# Run Locust
echo "Starting Locust..."
echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
echo "Test completion api first"
locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
echo "Test chat api"
locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
# kill Server
echo "Stopping server..."
kill $SERVER_PID

tests/test_infer/test_server.py (79 lines changed)

@@ -0,0 +1,79 @@
# inspired by vLLM
import subprocess
import sys
import time
import pytest
import ray
import requests
MAX_WAITING_TIME = 300
pytestmark = pytest.mark.asyncio
@ray.remote(num_gpus=1)
class ServerRunner:
def __init__(self, args):
self.proc = subprocess.Popen(
["python3", "-m", "colossalai.inference.server.api_server"] + args,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._wait_for_server()
def ready(self):
return True
def _wait_for_server(self):
# run health check
start = time.time()
while True:
try:
if requests.get("http://localhost:8000/v0/models").status_code == 200:
break
except Exception as err:
if self.proc.poll() is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_WAITING_TIME:
raise RuntimeError("Server failed to start in time.") from err
def __del__(self):
if hasattr(self, "proc"):
self.proc.terminate()
@pytest.fixture(scope="session")
def server():
ray.init()
server_runner = ServerRunner.remote(
[
"--model",
"/home/chenjianghai/data/llama-7b-hf",
]
)
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()
async def test_completion(server):
data = {"prompt": "How are you?", "stream": "False"}
response = await server.post("v1/completion", json=data)
assert response is not None
async def test_chat(server):
messages = [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"},
]
data = {"messages": messages, "stream": "False"}
response = await server.post("v1/chat", data)
assert response is not None
if __name__ == "__main__":
pytest.main([__file__])