import argparse
import logging
import random
from typing import Optional

import uvicorn
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer
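
# A FastAPI front end for serving Meta's OPT models through an EnergonAI engine.
# Incoming prompts are tokenized, batched by BatchManagerForGeneration, and
# optionally cached per (prompt, max_tokens) key so repeats skip regeneration.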


class GenerationTaskReq(BaseModel):
    max_tokens: int = Field(gt=0, le=256, example=64)
    prompt: str = Field(
        min_length=1,
        example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
    )
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
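
# A request body this schema accepts (values taken from the Field examples above):
# {
#     "max_tokens": 64,
#     "prompt": "Question: Where were the 2004 Olympics held?\nAnswer:",
#     "top_k": 50,
#     "top_p": 0.5,
#     "temperature": 0.7
# }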


app = FastAPI()


@app.post("/generation")
async def generate(data: GenerationTaskReq, request: Request):
    logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
    key = (data.prompt, data.max_tokens)
    try:
        if cache is None:
            raise MissCacheError()
        outputs = cache.get(key)
        output = random.choice(outputs)
        logger.info("Cache hit")
    except MissCacheError:
        inputs = tokenizer(data.prompt, truncation=True, max_length=512)
        inputs["max_tokens"] = data.max_tokens
        inputs["top_k"] = data.top_k
        inputs["top_p"] = data.top_p
        inputs["temperature"] = data.temperature
        try:
            uid = id(data)
            engine.submit(uid, inputs)
            output = await engine.wait(uid)
            output = tokenizer.decode(output, skip_special_tokens=True)
            if cache is not None:
                cache.add(key, output)
        except QueueFullError as e:
            raise HTTPException(status_code=406, detail=e.args[0])
    return {"text": output}
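
# Example call against a running server (assumes the default --http_port of 7070
# defined below; the prompt and max_tokens here are illustrative values):
#   curl -X POST http://localhost:7070/generation \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Question: What is the capital of France?\nAnswer:", "max_tokens": 16}'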


@app.on_event("shutdown")
async def shutdown(*_):
    engine.shutdown()
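    # server.run() at the bottom of this file blocks the main thread; setting
    # should_exit/force_exit is how we ask that uvicorn server to stop serving.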
    server.should_exit = True
    server.force_exit = True
    await server.shutdown()


def get_model_fn(model_name: str):
    model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
    return model_map[model_name]


def print_args(args: argparse.Namespace):
    print("\n==> Args:")
    for k, v in args.__dict__.items():
        print(f"{k} = {v}")
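

# (prompt, max_tokens) keys passed to ListCache as fixed_keys when --cache_size > 0,
# presumably so these demo prompts always stay resident in the cache.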
FIXED_CACHE_KEYS = [
    (
        "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
        64,
    ),
    (
        "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget?\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
        64,
    ),
    (
        "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese: ",
        64,
    ),
]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--master_host", default="localhost")
    parser.add_argument("--master_port", type=int, default=19990)
    parser.add_argument("--rpc_port", type=int, default=19980)
    parser.add_argument("--max_batch_size", type=int, default=8)
    parser.add_argument("--pipe_size", type=int, default=1)
    parser.add_argument("--queue_size", type=int, default=0)
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--cache_size", type=int, default=0)
    parser.add_argument("--cache_list_size", type=int, default=1)
    args = parser.parse_args()
    print_args(args)
    model_kwargs = {}
    if args.checkpoint is not None:
        model_kwargs["checkpoint"] = args.checkpoint

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
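    # All OPT model sizes share the same GPT-2 byte-level BPE tokenizer, so the
    # facebook/opt-30b tokenizer files below serve every choice of `model`.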
    tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
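    # With --cache_size > 0, up to --cache_list_size outputs are kept per
    # (prompt, max_tokens) key (see the local cache module), and the handler
    # above samples one of them at random on a hit.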
    if args.cache_size > 0:
        cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
    else:
        cache = None
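    # Launch the EnergonAI workers: tensor-parallel degree --tp; the second
    # positional argument looks like the pipeline-parallel degree, fixed at 1 here.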
    engine = launch_engine(
        args.tp,
        1,
        args.master_host,
        args.master_port,
        args.rpc_port,
        get_model_fn(args.model),
        batch_manager=BatchManagerForGeneration(
            max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
        ),
        pipe_size=args.pipe_size,
        queue_size=args.queue_size,
        **model_kwargs,
    )
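    # server.run() blocks until the shutdown handler above flips should_exit.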
    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()