2023-03-28 12:25:36 +00:00
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
from threading import Lock
|
2023-08-02 02:17:36 +00:00
|
|
|
from typing import Generator, List, Optional
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import uvicorn
|
2023-08-02 02:17:36 +00:00
|
|
|
from coati.quant import llama_load_quant, low_resource_init
|
|
|
|
from fastapi import FastAPI, Request
|
2023-03-28 12:25:36 +00:00
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
|
|
from slowapi.errors import RateLimitExceeded
|
|
|
|
from slowapi.util import get_remote_address
|
|
|
|
from sse_starlette.sse import EventSourceResponse
|
2023-08-02 02:17:36 +00:00
|
|
|
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
|
2023-07-18 10:03:08 +00:00
|
|
|
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
|
2023-03-28 12:25:36 +00:00
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
# Instruction preamble handed to ChatPromptProcessor below (presumably prepended
# to every prompt it builds -- confirm in utils.ChatPromptProcessor).
CONTEXT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request. "
    "Do not generate new instructions."
)

# Length budget passed to ChatPromptProcessor; presumably the maximum prompt
# length in tokens -- TODO confirm against ChatPromptProcessor.
MAX_LEN: int = 512

# Serializes access to the model: the non-streaming endpoint holds it around
# model.generate, and the streaming path wraps its sampler in LockedIterator.
running_lock = Lock()
|
|
|
|
|
|
|
|
|
|
|
|
class GenerationTaskReq(BaseModel):
    """Request body shared by the ``/generate`` and ``/generate/stream`` endpoints."""

    # Number of tokens to generate; must be in the range 1..512.
    max_new_tokens: int = Field(gt=0, le=512, example=64)
    # Conversation so far; at least one turn is required.
    history: List[Dialogue] = Field(min_items=1)
    # Optional top-k cutoff (> 0); None leaves the choice to the generator.
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    # Optional nucleus (top-p) mass, strictly between 0 and 1.
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    # Optional sampling temperature, strictly between 0 and 1.
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
    # Optional penalty (> 1.0). Only the non-streaming endpoint forwards it to the model.
    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI()

# Rate limits are tracked per client, keyed by the caller's remote address;
# individual endpoints opt in with @limiter.limit(...).
limiter = Limiter(key_func=get_remote_address)
# Make the limiter reachable from the app (slowapi expects it on app.state) and
# register the handler that turns RateLimitExceeded into an error response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
|
|
|
|
|
|
# set CORS
# CORS_ORIGIN, when set, is a comma-separated list of allowed origins, e.g.
# "https://a.example,https://b.example". When it is unset, every origin is allowed.
origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)

if origin_spec_from_env is not None:
    # allow CORS from the specified origins. Whitespace around entries is
    # stripped and empty entries (e.g. from a trailing comma) are dropped, so
    # "a, b" allows both "a" and "b".
    origins = [origin.strip() for origin in origin_spec_from_env.split(",") if origin.strip()]
else:
    # allow CORS from all origins
    origins = ["*"]

# NOTE(review): allow_credentials=True together with the "*" wildcard lets any
# site issue credentialed cross-origin requests -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
    """Yield the model's continuation of ``prompt`` as a stream of text pieces.

    Args:
        prompt: Prompt text; it is tokenized with the module-level ``tokenizer``
            and the tensors are moved to CUDA.
        max_new_tokens: Forwarded to ``sample_streamingly`` as ``max_generate_tokens``.
        top_k, top_p, temperature: Sampling parameters forwarded unchanged
            (``None`` included) to ``sample_streamingly``.

    Yields:
        str: Decoded text for every sampling step that produced at least one
        non-special token. Leading whitespace is stripped from the first piece.
        A later piece whose first token starts with the SentencePiece word marker
        "▁" is prefixed with one space, because the decoder drops that leading space.

    Model access is serialized through ``running_lock`` via ``LockedIterator``.
    """
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        "max_generate_tokens": max_new_tokens,
        "early_stopping": True,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "prepare_inputs_fn": model.prepare_inputs_for_generation,
        "update_model_kwargs_fn": update_model_kwargs_fn,
    }
    # `all_special_tokens` is a computed property that rebuilds a list on each
    # access; evaluate it once and use a set for O(1) membership tests.
    special_tokens = set(tokenizer.all_special_tokens)
    # Slow SentencePiece tokenizers expose `sp_model`; fast tokenizers (which
    # AutoTokenizer returns when available) do not. Fall back to the generic
    # tokens -> text conversion instead of raising AttributeError.
    sp_model = getattr(tokenizer, "sp_model", None)
    decode_tokens = sp_model.decode if sp_model is not None else tokenizer.convert_tokens_to_string

    is_first_word = True
    generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
    for output in generator:
        output = output.cpu()
        tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
        current_sub_tokens = [token for token in tokens if token not in special_tokens]
        if not current_sub_tokens:
            continue
        out_string = decode_tokens(current_sub_tokens)
        if is_first_word:
            out_string = out_string.lstrip()
            is_first_word = False
        elif current_sub_tokens[0].startswith("▁"):
            # whitespace will be ignored by the frontend
            out_string = " " + out_string
        yield out_string
|
|
|
|
|
|
|
|
|
|
|
|
async def event_generator(request: "Request", generator: Generator):
    """Adapt a blocking text generator into server-sent-event dicts.

    Each item from ``generator`` is emitted as ``{"event": "generate", "data": item}``.
    When the generator is exhausted, a final ``{"event": "end", "data": ""}`` is
    emitted. If the client disconnects, iteration stops without an ``end`` event.

    ``next(generator)`` can block for a long time; in this file each step drives
    model sampling. It is therefore run in the event loop's default executor
    rather than on the loop itself. Otherwise one streaming request would freeze
    every other coroutine, including disconnect detection, for the duration of
    each step. Calls remain strictly sequential, so the generator is never
    advanced concurrently.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    # StopIteration cannot be propagated through an asyncio Future, so exhaustion
    # is signalled by passing a private sentinel as next()'s default value.
    exhausted = object()
    while True:
        if await request.is_disconnected():
            break
        item = await loop.run_in_executor(None, next, generator, exhausted)
        if item is exhausted:
            yield {"event": "end", "data": ""}
            break
        yield {"event": "generate", "data": item}
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
@app.post("/generate/stream")
@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
    """Stream a chat reply to the client as server-sent events.

    The dialogue history is turned into a prompt by ``prompt_processor``. Text
    pieces come from ``generate_streamingly`` and are wrapped into SSE events by
    ``event_generator``. Only ``max_new_tokens``, ``top_k``, ``top_p`` and
    ``temperature`` are forwarded; ``repetition_penalty`` is not used on this path.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    token_stream = generate_streamingly(
        prompt,
        data.max_new_tokens,
        data.top_k,
        data.top_p,
        data.temperature,
    )
    return EventSourceResponse(event_generator(request, token_stream))
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
@app.post("/generate")
@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
    """Generate a complete chat reply and return it in a single response.

    Every request field except ``history`` is forwarded to ``model.generate`` as a
    keyword argument. The assembled prompt and the post-processed reply are both
    screened with ``prompt_processor.has_censored_words``. If either matches,
    ``prompt_processor.SAFE_RESPONSE`` is returned instead.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE

    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {name: tensor.cuda() for name, tensor in encoded.items()}
    generation_kwargs = data.dict(exclude={"history"})
    # Only one generation may use the model at a time.
    with running_lock:
        generated = model.generate(**inputs, **generation_kwargs)

    # Drop the first prompt_len positions of row 0, keeping only the new tokens.
    prompt_len = inputs["input_ids"].size(1)
    continuation = generated.cpu()[0, prompt_len:]
    reply = prompt_processor.postprocess_output(tokenizer.decode(continuation, skip_special_tokens=True))
    return prompt_processor.SAFE_RESPONSE if prompt_processor.has_censored_words(reply) else reply
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained",
        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
    )
    parser.add_argument(
        "--quant",
        choices=["8bit", "4bit"],
        default=None,
        help="Quantization mode. Default: None (no quantization, fp16).",
    )
    parser.add_argument(
        "--gptq_checkpoint",
        default=None,
        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
    )
    parser.add_argument(
        "--gptq_group_size",
        type=int,
        default=128,
        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
    )
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument(
        "--profanity_file",
        default=None,
        help="Path to profanity words list. It should be a JSON file containing a list of words.",
    )
    args = parser.parse_args()

    # Validate with parser.error rather than `assert`: asserts are stripped under
    # `python -O`, letting a None checkpoint path reach llama_load_quant.
    # parser.error prints the usage text and exits with status 2.
    if args.quant == "4bit" and args.gptq_checkpoint is None:
        parser.error("Please specify a GPTQ checkpoint (--gptq_checkpoint) when --quant is 4bit.")

    # The endpoint functions read `tokenizer`, `prompt_processor` and `model` as
    # module globals, so they are bound here at module scope.
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

    censored_words = load_json(args.profanity_file) if args.profanity_file is not None else []
    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

    if args.quant == "4bit":
        # Build the model skeleton from its config, then load the GPTQ 4-bit
        # weights into it. low_resource_init presumably avoids costly weight
        # initialization -- confirm in coati.quant.
        with low_resource_init():
            config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == "8bit"),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != "8bit":
            model.half()  # seems to fix bugs for some users.
    # Inference only: put the model in evaluation mode whichever way it was loaded.
    model.eval()

    server_config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=server_config)
    server.run()
|