""" Doc: Feature: - FastAPI based http server for Colossal-Inference - Completion Service Supported Usage: (for local user) - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model` - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \ -H 'Content-Type: application/json' \ -d '{"prompt":"hello, who are you? ","stream":"False"}'` Version: V1.0 """ import argparse import json import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from transformers import AutoModelForCausalLM, AutoTokenizer import colossalai from colossalai.inference.config import InferenceConfig from colossalai.inference.server.chat_service import ChatServing from colossalai.inference.server.completion_service import CompletionServing from colossalai.inference.server.utils import id_generator from colossalai.inference.utils import find_available_ports from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa TIMEOUT_KEEP_ALIVE = 5 # seconds. prompt_template_choices = ["llama", "vicuna"] async_engine = None chat_serving = None completion_serving = None app = FastAPI() @app.get("/ping") def health_check() -> JSONResponse: """Health Check for server.""" return JSONResponse({"status": "Healthy"}) @app.get("/engine_check") def engine_check() -> bool: """Check if the background loop is running.""" loop_status = async_engine.background_loop_status if loop_status == False: return JSONResponse({"status": "Error"}) return JSONResponse({"status": "Running"}) @app.post("/generate") async def generate(request: Request) -> Response: """Generate completion for the request. NOTE: THIS API IS USED ONLY FOR TESTING, DO NOT USE THIS IF YOU ARE IN ACTUAL APPLICATION. A request should be a JSON object with the following fields: - prompts: the prompts to use for the generation. - stream: whether to stream the results or not. - other fields: """ request_dict = await request.json() prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", "false") if isinstance(stream, str): stream = stream.lower() request_id = id_generator() generation_config = get_generation_config(request_dict) results = engine.generate(request_id, prompt, generation_config=generation_config) # Streaming case def stream_results(): for request_output in results: ret = {"text": request_output[len(prompt) :]} yield (json.dumps(ret) + "\0").encode("utf-8") if stream == "true" or stream == True: return StreamingResponse(stream_results()) # Non-streaming case final_output = None for request_output in results: if request.is_disconnected(): # Abort the request if the client disconnects. 
            engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output[len(prompt) :]

    assert final_output is not None
    ret = {"text": final_output}
    return JSONResponse(ret)


@app.post("/completion")
async def create_completion(request: Request):
    request_dict = await request.json()
    stream = request_dict.pop("stream", "false")
    if isinstance(stream, str):
        stream = stream.lower()

    generation_config = get_generation_config(request_dict)
    result = await completion_serving.create_completion(request, generation_config)

    ret = {"request_id": result.request_id, "text": result.output}
    if stream == "true" or stream is True:
        return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
    else:
        return JSONResponse(content=ret)


@app.post("/chat")
async def create_chat(request: Request):
    request_dict = await request.json()
    stream = request_dict.get("stream", "false")
    if isinstance(stream, str):
        stream = stream.lower()

    generation_config = get_generation_config(request_dict)
    message = await chat_serving.create_chat(request, generation_config)

    if stream == "true" or stream is True:
        return StreamingResponse(content=message, media_type="text/event-stream")
    else:
        ret = {"role": message.role, "text": message.content}
        return ret


def get_generation_config(request):
    """Override attributes of the engine's generation config with matching fields from the request."""
    generation_config = async_engine.engine.generation_config
    for arg in request:
        if hasattr(generation_config, arg):
            setattr(generation_config, arg, request[arg])
    return generation_config


def add_engine_config(parser):
    parser.add_argument(
        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
    )

    # Parallelism arguments are not supported yet.

    # KV cache arguments
    parser.add_argument("--block_size", type=int, default=16, choices=[16, 32], help="token block size")
    parser.add_argument("--max_batch_size", type=int, default=8, help="maximum batch size")
    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument(
        "--use_cuda_kernel", action="store_true", help="Use CUDA kernels (Triton kernels are used by default)"
    )

    # generation arguments
    parser.add_argument(
        "--prompt_template",
        choices=prompt_template_choices,
        default=None,
        help=f"Allowed choices are {','.join(prompt_template_choices)}. Defaults to None.",
    )
    return parser


def parse_args():
    parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000, help="port of the FastAPI server.")
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path-based routing proxy"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="The model name used in the API. If not "
If not " "specified, the model name will be the same as " "the huggingface name.", ) parser.add_argument( "--chat-template", type=str, default=None, help="The file path to the chat template, " "or the template in single-line form " "for the specified model", ) parser.add_argument( "--response-role", type=str, default="assistant", help="The role name to return if " "`request.add_generation_prompt=true`.", ) parser = add_engine_config(parser) return parser.parse_args() if __name__ == "__main__": args = parse_args() inference_config = InferenceConfig.from_dict(vars(args)) tokenizer = AutoTokenizer.from_pretrained(args.model) colossalai_backend_port = find_available_ports(1)[0] colossalai.launch( rank=0, world_size=1, host=args.host, port=colossalai_backend_port, backend="nccl", ) model = AutoModelForCausalLM.from_pretrained(args.model) async_engine = AsyncInferenceEngine( start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config ) engine = async_engine.engine completion_serving = CompletionServing(async_engine, model.__class__.__name__) chat_serving = ChatServing( async_engine, served_model=model.__class__.__name__, tokenizer=tokenizer, response_role=args.response_role, chat_template=args.chat_template, ) app.root_path = args.root_path uvicorn.run( app=app, host=args.host, port=args.port, log_level="debug", timeout_keep_alive=TIMEOUT_KEEP_ALIVE, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, )