import argparse
import logging
import random
from typing import Optional

from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from pydantic import BaseModel, Field
from sanic import Sanic
from sanic.request import Request
from sanic.response import json
from sanic_ext import openapi, validate
from torch import Tensor
from transformers import GPT2Tokenizer


class GenerationTaskReq(BaseModel):
    max_tokens: int = Field(gt=0, le=256, example=64)
    prompt: str = Field(
        min_length=1,
        example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
    )
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
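

# A request body matching the schema above might look like the following
# sketch (values are illustrative, not defaults):
#
#   {
#       "max_tokens": 64,
#       "prompt": "Question: What is the capital of France?\nAnswer:",
#       "top_k": 50,
#       "top_p": 0.5,
#       "temperature": 0.7
#   }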


app = Sanic("opt")


@app.post("/generation")
@openapi.body(GenerationTaskReq)
@validate(json=GenerationTaskReq)
async def generate(request: Request, body: GenerationTaskReq):
    logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
    key = (body.prompt, body.max_tokens)
    try:
        # Serve from the cache when possible; a disabled cache or a missing
        # key falls through to the inference engine below.
        if cache is None:
            raise MissCacheError()
        outputs = cache.get(key)
        output = random.choice(outputs)
        logger.info("Cache hit")
    except MissCacheError:
        # Tokenize the prompt and attach the sampling parameters for the engine.
        inputs = tokenizer(body.prompt, truncation=True, max_length=512)
        inputs["max_tokens"] = body.max_tokens
        inputs["top_k"] = body.top_k
        inputs["top_p"] = body.top_p
        inputs["temperature"] = body.temperature
        try:
            uid = id(body)
            engine.submit(uid, inputs)
            output = await engine.wait(uid)
            assert isinstance(output, Tensor)
            output = tokenizer.decode(output, skip_special_tokens=True)
            if cache is not None:
                cache.add(key, output)
        except QueueFullError as e:
            # Reject the request when the engine's submission queue is full.
            return json({"detail": e.args[0]}, status=406)
    return json({"text": output})
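

# Example call against a running server (a sketch; assumes the default
# --http_host/--http_port of 0.0.0.0:7070 configured below):
#
#   curl http://localhost:7070/generation \
#       -H 'Content-Type: application/json' \
#       -d '{"prompt": "Question: What is the capital of France?\nAnswer:", "max_tokens": 64}'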


@app.after_server_stop
def shutdown(*_):
    engine.shutdown()


def get_model_fn(model_name: str):
    model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
    return model_map[model_name]


def print_args(args: argparse.Namespace):
    print("\n==> Args:")
    for k, v in args.__dict__.items():
        print(f"{k} = {v}")
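

# (prompt, max_tokens) pairs passed as fixed_keys to ListCache below, so these
# demo prompts are pre-pinned cache keys; presumably fixed_keys entries are
# never evicted once cached.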
FIXED_CACHE_KEYS = [
    (
        "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
        64,
    ),
    (
        "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget?\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
        64,
    ),
    (
        "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
        64,
    ),
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--master_host", default="localhost")
    parser.add_argument("--master_port", type=int, default=19990)
    parser.add_argument("--rpc_port", type=int, default=19980)
    parser.add_argument("--max_batch_size", type=int, default=8)
    parser.add_argument("--pipe_size", type=int, default=1)
    parser.add_argument("--queue_size", type=int, default=0)
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--cache_size", type=int, default=0)
    parser.add_argument("--cache_list_size", type=int, default=1)
    args = parser.parse_args()
    print_args(args)
    model_kwargs = {}
    if args.checkpoint is not None:
        model_kwargs["checkpoint"] = args.checkpoint

    logger = logging.getLogger(__name__)
    tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
    if args.cache_size > 0:
        cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
    else:
        cache = None
    engine = launch_engine(
        args.tp,
        1,
        args.master_host,
        args.master_port,
        args.rpc_port,
        get_model_fn(args.model),
        batch_manager=BatchManagerForGeneration(
            max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
        ),
        pipe_size=args.pipe_size,
        queue_size=args.queue_size,
        **model_kwargs,
    )
    app.run(args.http_host, args.http_port)
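
# Example launch (a sketch; the script file name is an assumption):
#   python opt_server.py opt-125m --tp 2 --cache_size 8
# serves OPT-125M with 2-way tensor parallelism, enables the list cache, and
# listens on http://0.0.0.0:7070.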