# Source: ColossalAI/examples/tutorial/opt/inference/opt_fastapi.py
# (Python, 137 lines, 5.0 KiB in the original repository view)

import argparse
import logging
import random
from typing import Optional
import uvicorn
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer
# Request body accepted by the POST /generation endpoint.
# Documented with comments (not a docstring) so the generated schema is unchanged.
class GenerationTaskReq(BaseModel):
    # Required; must satisfy 0 < max_tokens <= 256. Forwarded to the engine as "max_tokens".
    max_tokens: int = Field(
        gt=0,
        le=256,
        example=64,
    )
    # Required, non-empty text that is tokenized and sent to the model.
    prompt: str = Field(
        min_length=1,
        example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
    )
    # Optional sampling controls; each defaults to None and is forwarded to the engine as-is.
    # top_k must be > 0 when given.
    top_k: Optional[int] = Field(
        default=None,
        gt=0,
        example=50,
    )
    # top_p must lie strictly between 0.0 and 1.0 when given.
    top_p: Optional[float] = Field(
        default=None,
        gt=0.0,
        lt=1.0,
        example=0.5,
    )
    # temperature must lie strictly between 0.0 and 1.0 when given.
    temperature: Optional[float] = Field(
        default=None,
        gt=0.0,
        lt=1.0,
        example=0.7,
    )
# FastAPI application instance; the route and shutdown hook defined below register on it.
app = FastAPI()
@app.post("/generation")
async def generate(data: GenerationTaskReq, request: Request):
    """Serve one text-generation request, answering from the cache when possible.

    Relies on the module-level ``logger``, ``cache``, ``tokenizer`` and
    ``engine`` objects created by the ``__main__`` block of this file.

    Args:
        data: Validated request body (prompt plus sampling options).
        request: Incoming HTTP request; used only for the access log line.

    Returns:
        A dict ``{"text": output}`` holding the cached or freshly generated text.

    Raises:
        HTTPException: Status 406 when submitting to the engine raises
            ``QueueFullError``.
    """
    # ``request.client`` can be None when the server supplies no peer address,
    # so avoid dereferencing it blindly just to write a log line.
    client = request.client
    client_addr = f"{client.host}:{client.port}" if client is not None else "unknown"
    logger.info(f'{client_addr} - "{request.method} {request.url.path}" - {data}')

    # The cache key ignores top_k / top_p / temperature: only prompt and length matter.
    key = (data.prompt, data.max_tokens)
    try:
        if cache is None:
            # Caching disabled: reuse the miss path below.
            raise MissCacheError()
        outputs = cache.get(key)
        # Several outputs may be stored per key; pick one at random.
        output = random.choice(outputs)
        logger.info("Cache hit")
    except MissCacheError:
        inputs = tokenizer(data.prompt, truncation=True, max_length=512)
        inputs["max_tokens"] = data.max_tokens
        inputs["top_k"] = data.top_k
        inputs["top_p"] = data.top_p
        inputs["temperature"] = data.temperature
        try:
            # id() is unique among live objects, and ``data`` stays alive for
            # the whole request, so it serves as the task identifier.
            uid = id(data)
            engine.submit(uid, inputs)
            output = await engine.wait(uid)
            output = tokenizer.decode(output, skip_special_tokens=True)
            if cache is not None:
                cache.add(key, output)
        except QueueFullError as e:
            # Keep the original error as __cause__ and tolerate an argument-less error.
            detail = e.args[0] if e.args else "Task queue is full"
            raise HTTPException(status_code=406, detail=detail) from e
    return {"text": output}
@app.on_event("shutdown")
async def shutdown(*_):
    """Application shutdown hook: stop the engine, then stop the uvicorn server.

    Uses the module-level ``engine`` and ``server`` objects created by the
    ``__main__`` block; any positional arguments are ignored.
    """
    # Stop the inference engine before touching the HTTP server.
    engine.shutdown()
    # Set both exit flags (should_exit first, then force_exit) before awaiting shutdown.
    server.should_exit = server.force_exit = True
    await server.shutdown()
def get_model_fn(model_name: str):
    """Return the energonai model constructor registered under ``model_name``.

    Args:
        model_name: One of "opt-125m", "opt-6.7b", "opt-30b", "opt-175b".

    Raises:
        KeyError: If ``model_name`` is not one of the names above.
    """
    constructors = {
        "opt-125m": opt_125M,
        "opt-6.7b": opt_6B,
        "opt-30b": opt_30B,
        "opt-175b": opt_175B,
    }
    return constructors[model_name]
def print_args(args: argparse.Namespace):
    """Print a header followed by one ``name = value`` line per parsed argument."""
    print("\n==> Args:")
    # vars() exposes the namespace's attribute dict, in insertion order.
    for name, value in vars(args).items():
        print(f"{name} = {value}")
# (prompt, max_tokens) pairs handed to ListCache as ``fixed_keys`` when caching is enabled.
# Same tuple shape as the cache key built in generate().
# NOTE(review): presumably these entries are pinned in the cache and never evicted —
# confirm against cache.ListCache.
FIXED_CACHE_KEYS = [
    # Question-answering prompt.
    (
        "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
        64,
    ),
    # Dialogue-continuation prompt.
    (
        "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
        64,
    ),
    # English-to-Chinese translation prompt.
    (
        "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
        64,
    ),
]
if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
    # Optional flags, registered in the order they appear in --help.
    option_specs = [
        ("--tp", {"type": int, "default": 1}),
        ("--master_host", {"default": "localhost"}),
        ("--master_port", {"type": int, "default": 19990}),
        ("--rpc_port", {"type": int, "default": 19980}),
        ("--max_batch_size", {"type": int, "default": 8}),
        ("--pipe_size", {"type": int, "default": 1}),
        ("--queue_size", {"type": int, "default": 0}),
        ("--http_host", {"default": "0.0.0.0"}),
        ("--http_port", {"type": int, "default": 7070}),
        ("--checkpoint", {"default": None}),
        ("--cache_size", {"type": int, "default": 0}),
        ("--cache_list_size", {"type": int, "default": 1}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    print_args(args)

    # Only forward a checkpoint to the model constructor when one was given.
    model_kwargs = {} if args.checkpoint is None else {"checkpoint": args.checkpoint}

    # ---- Globals used by the request handlers ------------------------------
    logger = logging.getLogger(__name__)
    # The same tokenizer is loaded regardless of which model size was selected.
    tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
    # A cache_size of 0 (the default) disables response caching entirely.
    cache = (
        ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
        if args.cache_size > 0
        else None
    )

    # ---- Inference engine ---------------------------------------------------
    model_fn = get_model_fn(args.model)
    batch_manager = BatchManagerForGeneration(
        max_batch_size=args.max_batch_size,
        pad_token_id=tokenizer.pad_token_id,
    )
    # NOTE(review): the second positional argument is fixed at 1 — presumably the
    # pipeline-parallel world size; confirm against energonai.launch_engine.
    engine = launch_engine(
        args.tp,
        1,
        args.master_host,
        args.master_port,
        args.rpc_port,
        model_fn,
        batch_manager=batch_manager,
        pipe_size=args.pipe_size,
        queue_size=args.queue_size,
        **model_kwargs,
    )

    # ---- HTTP server --------------------------------------------------------
    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()