@@ -20,10 +20,12 @@ from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer

 import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports

 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa
@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
     results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ async def generate(request: Request) -> Response:
             ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(stream_results())

     # Non-streaming case
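The same accept-a-bool-or-a-string handling is applied to the /completion and /chat routes in the next hunks. The reason for the `isinstance` guard: after `request.json()`, a JSON `true` arrives as a Python `bool` while `"true"` arrives as a `str`, so calling `.lower()` unconditionally raises `AttributeError`. Below is a minimal client sketch against the /generate route above, assuming the server in this file is running on its defaults (127.0.0.1:8000), that the `requests` package is available, and using a made-up prompt; the non-streaming response body is not shown in this hunk, so it is only printed raw.

# Hypothetical client for the /generate route above; host, port and prompt are assumptions.
import json

import requests

url = "http://127.0.0.1:8000/generate"

# JSON boolean: the handler receives a Python bool, hence the isinstance(stream, str) guard.
with requests.post(url, json={"prompt": "Hello", "stream": True}, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"])  # each chunk is the JSON blob yielded above

# JSON string: still accepted and lowercased before the "true"/"false" comparison.
resp = requests.post(url, json={"prompt": "Hello", "stream": "false"})
print(resp.text)  # non-streaming branch; its exact body is outside this hunk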
@@ -86,12 +89,14 @@ async def generate(request: Request) -> Response:
 @app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)

     ret = {"request_id": result.request_id, "text": result.output}
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_completion(request: Request):
 async def create_chat(request: Request):
     request_dict = await request.json()

-    stream = request_dict.get("stream", "false").lower()
+    stream = request_dict.get("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     message = await chat_serving.create_chat(request, generation_config)
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=message, media_type="text/event-stream")
     else:
         ret = {"role": message.role, "text": message.content}
@@ -115,27 +122,29 @@ def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
         if hasattr(generation_config, arg):
-            generation_config[arg] = request[arg]
+            setattr(generation_config, arg, request[arg])
     return generation_config


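A side note on the `setattr` change just above: the engine's `generation_config` is a `transformers`-style `GenerationConfig`, i.e. a plain object with no `__setitem__`, so the removed item-assignment line raises `TypeError` while attribute assignment works. A small standalone sketch of that difference (the field name and value are illustrative only):

# Illustrative only: why setattr() replaces item assignment in get_generation_config.
from transformers import GenerationConfig

config = GenerationConfig()
try:
    config["max_new_tokens"] = 64  # GenerationConfig defines no __setitem__
except TypeError as err:
    print(f"item assignment fails: {err}")

setattr(config, "max_new_tokens", 64)  # attribute assignment works
print(config.max_new_tokens)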
 def add_engine_config(parser):
-    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")

     parser.add_argument(
-        "--max-model-len",
-        type=int,
-        default=None,
-        help="model context length. If unspecified, " "will be automatically derived from the model.",
+        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
     )
-    # Parallel arguments
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+    # Parallel arguments not supported now

     # KV cache arguments
     parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")

     parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")

+    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+
+    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+
     # generation arguments
     parser.add_argument(
         "--prompt_template",
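The short options added above (`-m`, `-i`, `-o`, `-d`) behave like any other `argparse` aliases, and the parsed namespace is what `InferenceConfig.from_dict(vars(args))` consumes later in the file. A trimmed-down, self-contained sketch that recreates just two of those arguments (not the full `add_engine_config`):

# Trimmed-down sketch: only two of the new flags are recreated here for illustration.
import argparse

parser = argparse.ArgumentParser(description="short-option sketch")
parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")

args = parser.parse_args(["-i", "256", "-o", "64"])
print(vars(args))  # {'max_input_len': 256, 'max_output_len': 64}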
@@ -150,7 +159,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Colossal-Inference API server.")

     parser.add_argument("--host", type=str, default="127.0.0.1")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
         "specified, the model name will be the same as "
         "the huggingface name.",
     )

     parser.add_argument(
         "--chat-template",
         type=str,
@@ -184,13 +194,21 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     inference_config = InferenceConfig.from_dict(vars(args))
-    model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    colossalai_backend_port = find_available_ports(1)[0]
+    colossalai.launch(
+        rank=0,
+        world_size=1,
+        host=args.host,
+        port=colossalai_backend_port,
+        backend="nccl",
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model)
     async_engine = AsyncInferenceEngine(
-        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+        start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
     )
     engine = async_engine.engine
-    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+    completion_serving = CompletionServing(async_engine, model.__class__.__name__)
     chat_serving = ChatServing(
         async_engine,
         served_model=model.__class__.__name__,
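The backend port handed to `colossalai.launch` is now picked at runtime with `find_available_ports(1)[0]` rather than hard-coded. A rough standard-library illustration of the same idea (this is not ColossalAI's implementation, just the usual bind-to-port-0 trick):

# Not ColossalAI's find_available_ports; just the common way to ask the OS for a free port.
import socket


def pick_free_port(host: str = "127.0.0.1") -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))  # port 0 -> the OS assigns an unused port
        return s.getsockname()[1]


print(pick_free_port())  # e.g. 54021; this is the kind of value passed as port= above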