mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
5.0 KiB
137 lines
5.0 KiB
import argparse
|
|
import logging
|
|
import random
|
|
from typing import Optional
|
|
|
|
import uvicorn
|
|
from batch import BatchManagerForGeneration
|
|
from cache import ListCache, MissCacheError
|
|
from energonai import QueueFullError, launch_engine
|
|
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from pydantic import BaseModel, Field
|
|
from transformers import GPT2Tokenizer
|
|
|
|
|
|
class GenerationTaskReq(BaseModel):
    # Request body accepted by POST /generation. The gt/le/lt/min_length
    # constraints below are enforced by pydantic validation.

    # Number of new tokens to generate (1..256 inclusive).
    max_tokens: int = Field(gt=0, le=256, example=64)
    # Prompt text; must be non-empty.
    prompt: str = Field(
        min_length=1,
        example=(
            "Question: Where were the 2004 Olympics held?\n"
            "Answer: Athens, Greece\n\n"
            "Question: What is the longest river on the earth?\n"
            "Answer:"
        ),
    )
    # Optional sampling controls. When omitted they stay None and are passed
    # through unchanged into the engine inputs by ``generate``.
    top_k: Optional[int] = Field(None, gt=0, example=50)
    top_p: Optional[float] = Field(None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(None, gt=0.0, lt=1.0, example=0.7)
|
|
|
|
|
|
# ASGI application object. The /generation endpoint and the shutdown hook
# are registered on it via decorators, and it is served by uvicorn in the
# ``__main__`` section at the bottom of this file.
app = FastAPI()
|
|
|
|
|
|
@app.post("/generation")
async def generate(data: GenerationTaskReq, request: Request):
    """Generate a completion for ``data.prompt``, serving from the cache when possible.

    The cache key is ``(prompt, max_tokens)`` only: ``top_k``, ``top_p`` and
    ``temperature`` are not part of it, so a cache hit may return text that
    was generated with different sampling settings.

    Returns ``{"text": <generated text>}``. Raises ``HTTPException`` with
    status 406 when the engine reports that its queue is full.
    """
    # Starlette's ``request.client`` is Optional (None when the ASGI scope has
    # no client address); never fail the request because of a log line.
    client = request.client
    peer = f"{client.host}:{client.port}" if client is not None else "unknown"
    logger.info(f'{peer} - "{request.method} {request.url.path}" - {data}')

    key = (data.prompt, data.max_tokens)
    if cache is not None:
        try:
            outputs = cache.get(key)
        except MissCacheError:
            pass  # fall through to the engine below
        else:
            # cache.get yields a sequence of stored outputs (the ListCache is
            # sized by --cache_list_size); return one of them at random.
            output = random.choice(outputs)
            logger.info("Cache hit")
            return {"text": output}

    # Cache miss, or caching disabled: tokenize the prompt and run the engine.
    inputs = tokenizer(data.prompt, truncation=True, max_length=512)
    inputs["max_tokens"] = data.max_tokens
    inputs["top_k"] = data.top_k
    inputs["top_p"] = data.top_p
    inputs["temperature"] = data.temperature

    # id() is unique among live objects, and ``data`` stays referenced until
    # this coroutine returns, so in-flight requests get distinct uids.
    uid = id(data)
    try:
        engine.submit(uid, inputs)
        output = await engine.wait(uid)
    except QueueFullError as e:
        detail = e.args[0] if e.args else str(e)
        raise HTTPException(status_code=406, detail=detail) from e

    output = tokenizer.decode(output, skip_special_tokens=True)
    if cache is not None:
        cache.add(key, output)
    return {"text": output}
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown(*_):
    """Stop the inference engine, then stop the uvicorn server.

    Registered as the app's "shutdown" event hook. ``engine`` and ``server``
    are module globals created in the ``__main__`` section.
    """
    engine.shutdown()
    # Chained assignment sets should_exit, then force_exit, both before the
    # await. NOTE(review): force_exit presumably keeps uvicorn's own shutdown
    # from waiting on connections / re-running lifespan hooks; verify.
    server.should_exit = server.force_exit = True
    await server.shutdown()
|
|
|
|
|
|
def get_model_fn(model_name: str):
    """Look up the energonai OPT model function for a ``--model`` choice.

    The returned callable is what ``launch_engine`` receives. An unknown
    name raises ``KeyError``.
    """
    registry = {
        "opt-125m": opt_125M,
        "opt-6.7b": opt_6B,
        "opt-30b": opt_30B,
        "opt-175b": opt_175B,
    }
    return registry[model_name]
|
|
|
|
|
|
def print_args(args: argparse.Namespace):
    """Write a header followed by one ``name = value`` line per parsed option."""
    rows = [f"{name} = {value}" for name, value in vars(args).items()]
    print("\n".join(["\n==> Args:", *rows]))
|
|
|
|
|
|
# Cache keys handed to ``ListCache`` as ``fixed_keys`` in the ``__main__``
# section. Each entry has the same ``(prompt, max_tokens)`` shape as the key
# built in ``generate``. NOTE(review): presumably these entries are kept
# permanently rather than evicted; confirm in cache.py.
FIXED_CACHE_KEYS = [
    # Few-shot question answering.
    (
        "Question: What is the name of the largest continent on earth?\n"
        "Answer: Asia\n\n"
        "Question: What is at the center of the solar system?\n"
        "Answer:",
        64,
    ),
    # Dialogue continuation.
    (
        "A chat between a salesman and a student.\n\n"
        "Salesman: Hi boy, are you looking for a new phone?\n"
        "Student: Yes, my phone is not functioning well.\n"
        "Salesman: What is your budget? \n"
        "Student: I have received my scholarship so I am fine with any phone.\n"
        "Salesman: Great, then perhaps this latest flagship phone is just right for you.",
        64,
    ),
    # Few-shot English -> Chinese translation.
    (
        "English: I am happy today.\n"
        "Chinese: 我今天很开心。\n\n"
        "English: I am going to play basketball.\n"
        "Chinese: 我一会去打篮球。\n\n"
        "English: Let's celebrate our anniversary.\n"
        "Chinese:",
        64,
    ),
]
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--master_host", default="localhost")
    parser.add_argument("--master_port", type=int, default=19990)
    parser.add_argument("--rpc_port", type=int, default=19980)
    parser.add_argument("--max_batch_size", type=int, default=8)
    parser.add_argument("--pipe_size", type=int, default=1)
    parser.add_argument("--queue_size", type=int, default=0)
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--cache_size", type=int, default=0)
    parser.add_argument("--cache_list_size", type=int, default=1)
    args = parser.parse_args()
    print_args(args)

    # Forward ``checkpoint`` to the model function only when one was given.
    model_kwargs = {}
    if args.checkpoint is not None:
        model_kwargs["checkpoint"] = args.checkpoint

    # ``generate`` reports every request and cache hit via ``logger.info``.
    # An unconfigured logger drops INFO records: the root logger defaults to
    # WARNING, and the last-resort handler only prints WARNING and above.
    # basicConfig() is a no-op if logging is already configured. The level is
    # raised on this logger alone, so other libraries' verbosity is unchanged.
    logging.basicConfig()
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # NOTE(review): the opt-30b tokenizer is used for every --model choice;
    # presumably all OPT sizes share one vocabulary. Confirm.
    tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")

    # The response cache is optional. --cache_size 0 (the default) disables
    # it, and ``generate`` then always runs the engine.
    if args.cache_size > 0:
        cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
    else:
        cache = None

    engine = launch_engine(
        args.tp,
        1,  # NOTE(review): presumably the pipeline-parallel world size; verify against energonai
        args.master_host,
        args.master_port,
        args.rpc_port,
        get_model_fn(args.model),
        batch_manager=BatchManagerForGeneration(
            max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
        ),
        pipe_size=args.pipe_size,
        queue_size=args.queue_size,
        **model_kwargs,
    )

    # ``server`` is a module global so the app's shutdown hook can stop it.
    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()
|