mirror of https://github.com/hpcaitech/ColossalAI
[Online Server] Chat Api for streaming and not streaming response (#5470)
* fix bugs * fix bugs * fix api server * fix api server * add chat api and test * del request.nfeat/online-serving
parent
de378cd2ab
commit
c064032865
|
@ -11,7 +11,6 @@ Doc:
|
|||
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
@ -21,16 +20,20 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.server.chat_service import ChatServing
|
||||
from colossalai.inference.server.completion_service import CompletionServing
|
||||
from colossalai.inference.server.utils import id_generator
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
|
||||
prompt_template_choices = ["llama", "vicuna"]
|
||||
async_engine = None
|
||||
chat_serving = None
|
||||
completion_serving = None
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/v0/models")
|
||||
|
@ -49,7 +52,7 @@ async def generate(request: Request) -> Response:
|
|||
"""
|
||||
request_dict = await request.json()
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", None)
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
|
||||
request_id = id_generator()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
|
@ -61,7 +64,7 @@ async def generate(request: Request) -> Response:
|
|||
ret = {"text": request_output[len(prompt) :]}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
|
||||
if stream:
|
||||
if stream == "true":
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
|
@ -81,17 +84,31 @@ async def generate(request: Request) -> Response:
|
|||
@app.post("/v1/completion")
|
||||
async def create_completion(request: Request):
|
||||
request_dict = await request.json()
|
||||
stream = request_dict.pop("stream", False)
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
result = await completion_serving.create_completion(request, generation_config)
|
||||
|
||||
ret = {"request_id": result.request_id, "text": result.output}
|
||||
if stream:
|
||||
if stream == "true":
|
||||
return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
|
||||
else:
|
||||
return JSONResponse(content=ret)
|
||||
|
||||
|
||||
@app.post("/v1/chat")
|
||||
async def create_chat(request: Request):
|
||||
request_dict = await request.json()
|
||||
|
||||
stream = request_dict.get("stream", "false").lower()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
message = await chat_serving.create_chat(request, generation_config)
|
||||
if stream == "true":
|
||||
return StreamingResponse(content=message, media_type="text/event-stream")
|
||||
else:
|
||||
ret = {"role": message.role, "text": message.content}
|
||||
return ret
|
||||
|
||||
|
||||
def get_generation_config(request):
|
||||
generation_config = async_engine.engine.generation_config
|
||||
for arg in request:
|
||||
|
@ -175,6 +192,18 @@ def parse_args():
|
|||
"specified, the model name will be the same as "
|
||||
"the huggingface name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the chat template, " "or the template in single-line form " "for the specified model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--response-role",
|
||||
type=str,
|
||||
default="assistant",
|
||||
help="The role name to return if " "`request.add_generation_prompt=true`.",
|
||||
)
|
||||
parser = add_engine_config(parser)
|
||||
|
||||
return parser.parse_args()
|
||||
|
@ -182,7 +211,6 @@ def parse_args():
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
inference_config = InferenceConfig.from_dict(vars(args))
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
|
@ -191,10 +219,16 @@ if __name__ == "__main__":
|
|||
)
|
||||
engine = async_engine.engine
|
||||
completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
|
||||
|
||||
chat_serving = ChatServing(
|
||||
async_engine,
|
||||
served_model=model.__class__.__name__,
|
||||
tokenizer=tokenizer,
|
||||
response_role=args.response_role,
|
||||
chat_template=args.chat_template,
|
||||
)
|
||||
app.root_path = args.root_path
|
||||
uvicorn.run(
|
||||
app,
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
import asyncio
|
||||
import codecs
|
||||
import logging
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator
|
||||
|
||||
logger = logging.getLogger("colossalai-inference")
|
||||
|
||||
|
||||
class ChatServing:
|
||||
def __init__(
|
||||
self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None
|
||||
):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
self.tokenizer = tokenizer
|
||||
self.response_role = response_role
|
||||
self._load_chat_template(chat_template)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def create_chat(self, request: Request, generation_config):
|
||||
request_dict = await request.json()
|
||||
messages = request_dict["messages"]
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
add_generation_prompt = request_dict.pop("add_generation_prompt", False)
|
||||
request_id = id_generator()
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversation=messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in applying chat template from request: {str(e)}")
|
||||
|
||||
# it is not a intuitive way
|
||||
self.engine.engine.generation_config = generation_config
|
||||
result_generator = self.engine.generate(request_id, prompt=prompt)
|
||||
|
||||
if stream == "true":
|
||||
return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id)
|
||||
else:
|
||||
return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id)
|
||||
|
||||
async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int):
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = self.get_chat_request_role(request, request_dict)
|
||||
n = request_dict.get("n", 1)
|
||||
echo = request_dict.get("echo", "false").lower()
|
||||
for i in range(n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role))
|
||||
data = choice_data.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if echo == "true":
|
||||
last_msg_content = ""
|
||||
if (
|
||||
request_dict["messages"]
|
||||
and isinstance(request_dict["messages"], list)
|
||||
and request_dict["messages"][-1].get("content")
|
||||
and request_dict["messages"][-1].get("role") == role
|
||||
):
|
||||
last_msg_content = request_dict["messages"][-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, message=DeltaMessage(content=last_msg_content)
|
||||
)
|
||||
data = choice_data.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
result = await result_generator
|
||||
choice_data = DeltaMessage(content=result.output)
|
||||
data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self,
|
||||
request: Request,
|
||||
request_dict: dict,
|
||||
result_generator,
|
||||
request_id,
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return {"error_msg": "Client disconnected"}
|
||||
|
||||
result = await result_generator
|
||||
assert result is not None
|
||||
role = self.get_chat_request_role(request, request_dict)
|
||||
choice_data = ChatMessage(role=role, content=result.output)
|
||||
echo = request_dict.get("echo", "false").lower()
|
||||
|
||||
if echo == "true":
|
||||
last_msg_content = ""
|
||||
if (
|
||||
request.messages
|
||||
and isinstance(request.messages, list)
|
||||
and request.messages[-1].get("content")
|
||||
and request.messages[-1].get("role") == role
|
||||
):
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
full_message = last_msg_content + choice_data.content
|
||||
choice_data.content = full_message
|
||||
|
||||
return choice_data
|
||||
|
||||
def get_chat_request_role(self, request: Request, request_dict: dict) -> str:
|
||||
add_generation_prompt = request_dict.get("add_generation_prompt", False)
|
||||
if add_generation_prompt:
|
||||
return self.response_role
|
||||
else:
|
||||
return request_dict["messages"][-1]["role"]
|
||||
|
||||
def _load_chat_template(self, chat_template):
|
||||
if chat_template is not None:
|
||||
try:
|
||||
with open(chat_template, "r") as f:
|
||||
self.tokenizer.chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")
|
||||
|
||||
logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}")
|
||||
elif self.tokenizer.chat_template is not None:
|
||||
logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}")
|
||||
else:
|
||||
logger.warning("No chat template provided. Chat API will not work.")
|
|
@ -1,3 +1,8 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# make it singleton
|
||||
class NumericIDGenerator:
|
||||
_instance = None
|
||||
|
@ -14,3 +19,18 @@ class NumericIDGenerator:
|
|||
|
||||
|
||||
id_generator = NumericIDGenerator()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: Any
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[Any] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
message: DeltaMessage
|
||||
|
|
|
@ -165,12 +165,13 @@ class Sequence:
|
|||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt}, "
|
||||
f"output_token_id={self.output_token_id},"
|
||||
f"status={self.status.name}, "
|
||||
f"sample_params={self.sample_params}, "
|
||||
f"input_len={self.input_len},"
|
||||
f"output_len={self.output_len})"
|
||||
f"prompt={self.prompt},\n"
|
||||
f"output_token_id={self.output_token_id},\n"
|
||||
f"output={self.output},\n"
|
||||
f"status={self.status.name},\n"
|
||||
f"sample_params={self.sample_params},\n"
|
||||
f"input_len={self.input_len},\n"
|
||||
f"output_len={self.output_len})\n"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -14,9 +14,37 @@ class QuickstartUser(HttpUser):
|
|||
def completion_streaming(self):
|
||||
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
|
||||
|
||||
@tag("online-chat")
|
||||
@task(5)
|
||||
def chat(self):
|
||||
self.client.post(
|
||||
"v1/chat",
|
||||
json={
|
||||
"converation": [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},
|
||||
],
|
||||
"stream": "False",
|
||||
},
|
||||
)
|
||||
|
||||
@tag("online-chat")
|
||||
@task(5)
|
||||
def chat_streaming(self):
|
||||
self.client.post(
|
||||
"v1/chat",
|
||||
json={
|
||||
"converation": [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},
|
||||
],
|
||||
"stream": "True",
|
||||
},
|
||||
)
|
||||
|
||||
@tag("offline-generation")
|
||||
@task(5)
|
||||
def generate_stream(self):
|
||||
def generate_streaming(self):
|
||||
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
|
||||
|
||||
@tag("offline-generation")
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
|
||||
# launch server
|
||||
model_path=${1:-"lmsys/vicuna-7b-v1.3"}
|
||||
chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
|
||||
echo "Model Path: $model_path"
|
||||
echo "Starting server..."
|
||||
python -m colossalai.inference.server.api_server --model $model_path &
|
||||
python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
|
||||
SERVER_PID=$!
|
||||
|
||||
# waiting time
|
||||
|
@ -15,8 +16,10 @@ sleep 60
|
|||
# Run Locust
|
||||
echo "Starting Locust..."
|
||||
echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
|
||||
echo "Test completion api first"
|
||||
locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
|
||||
|
||||
echo "Test chat api"
|
||||
locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
|
||||
# kill Server
|
||||
echo "Stopping server..."
|
||||
kill $SERVER_PID
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# inspired by vLLM
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import requests
|
||||
|
||||
MAX_WAITING_TIME = 300
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class ServerRunner:
|
||||
def __init__(self, args):
|
||||
self.proc = subprocess.Popen(
|
||||
["python3", "-m", "colossalai.inference.server.api_server"] + args,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
self._wait_for_server()
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def _wait_for_server(self):
|
||||
# run health check
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
if requests.get("http://localhost:8000/v0/models").status_code == 200:
|
||||
break
|
||||
except Exception as err:
|
||||
if self.proc.poll() is not None:
|
||||
raise RuntimeError("Server exited unexpectedly.") from err
|
||||
|
||||
time.sleep(0.5)
|
||||
if time.time() - start > MAX_WAITING_TIME:
|
||||
raise RuntimeError("Server failed to start in time.") from err
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "proc"):
|
||||
self.proc.terminate()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
ray.init()
|
||||
server_runner = ServerRunner.remote(
|
||||
[
|
||||
"--model",
|
||||
"/home/chenjianghai/data/llama-7b-hf",
|
||||
]
|
||||
)
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
async def test_completion(server):
|
||||
data = {"prompt": "How are you?", "stream": "False"}
|
||||
response = await server.post("v1/completion", json=data)
|
||||
assert response is not None
|
||||
|
||||
|
||||
async def test_chat(server):
|
||||
messages = [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},
|
||||
]
|
||||
data = {"messages": messages, "stream": "False"}
|
||||
response = await server.post("v1/chat", data)
|
||||
assert response is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
Loading…
Reference in New Issue