import argparse
import os
from threading import Lock
from typing import Generator, List, Optional

import torch
import uvicorn
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import (
    ChatPromptProcessor,
    Dialogue,
    LockedIterator,
    load_json,
    sample_streamingly,
    update_model_kwargs_fn,
)

CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
MAX_LEN = 512
# Serializes model use: held around model.generate in /generate and passed to
# LockedIterator for each streaming step in /generate/stream.
running_lock = Lock()
# Sentinel returned by next() when the streaming generator is exhausted; a
# StopIteration cannot be propagated out of a thread-pool call/coroutine.
_STREAM_END = object()


# Request body shared by /generate and /generate/stream. Sampling fields left as
# None are forwarded as None to the generation backend.
class GenerationTaskReq(BaseModel):
    max_new_tokens: int = Field(gt=0, le=512, example=64)
    history: List[Dialogue] = Field(min_items=1)
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
    # Only honoured by /generate; the streaming path does not support it (see TODO below).
    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)


limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# set CORS
origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
if origin_spec_from_env is not None:
    # allow CORS from the specified comma-separated origins (surrounding whitespace ignored)
    origins = [origin.strip() for origin in origin_spec_from_env.split(",") if origin.strip()]
else:
    # allow CORS from all origins
    origins = ["*"]

# NOTE(review): allow_credentials=True combined with the wildcard origin default
# lets any site make credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
    """Yield the model's response to ``prompt`` as a sequence of decoded text chunks.

    Each sampling step is pulled through a ``LockedIterator`` guarded by
    ``running_lock``. Special tokens are dropped. The first chunk has its
    leading whitespace stripped, and later chunks that start a new word are
    prefixed with a space.

    Args:
        prompt: Fully preprocessed prompt text.
        max_new_tokens: Upper bound on generated tokens.
        top_k, top_p, temperature: Sampling parameters (may be None).

    Yields:
        Non-empty decoded text fragments.
    """
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        "max_generate_tokens": max_new_tokens,
        "early_stopping": True,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "prepare_inputs_fn": model.prepare_inputs_for_generation,
        "update_model_kwargs_fn": update_model_kwargs_fn,
    }
    # Computed once: the special-token set does not change between steps, and a
    # set gives O(1) membership instead of rebuilding/scanning a list per token.
    special_tokens = set(tokenizer.all_special_tokens)
    is_first_word = True
    generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
    for output in generator:
        output = output.cpu()
        tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
        current_sub_tokens = [token for token in tokens if token not in special_tokens]
        if current_sub_tokens:
            out_string = tokenizer.sp_model.decode(current_sub_tokens)
            if is_first_word:
                out_string = out_string.lstrip()
                is_first_word = False
            elif current_sub_tokens[0].startswith("▁"):
                # "▁" marks a word boundary in SentencePiece pieces; sp_model.decode
                # drops the leading space, and the frontend ignores whitespace
                # otherwise, so re-add it explicitly.
                out_string = " " + out_string
            yield out_string


async def event_generator(request: Request, generator: Generator):
    """Adapt a blocking text generator into server-sent events.

    Emits one ``generate`` event per text chunk and a final ``end`` event.
    Stops early if the client disconnects. Each ``next()`` call runs in a
    worker thread because it performs model work and may wait on
    ``running_lock``; calling it directly would freeze the event loop for
    every other request. The generator is closed on exit so resources are
    released promptly after a disconnect.
    """
    try:
        while True:
            if await request.is_disconnected():
                break
            chunk = await run_in_threadpool(next, generator, _STREAM_END)
            if chunk is _STREAM_END:
                yield {"event": "end", "data": ""}
                break
            yield {"event": "generate", "data": chunk}
    finally:
        generator.close()


async def _safe_response_events():
    """Emit the prompt processor's safe response as a single chunk, then end."""
    yield {"event": "generate", "data": prompt_processor.SAFE_RESPONSE}
    yield {"event": "end", "data": ""}


@app.post("/generate/stream")
@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
    """Stream a response as server-sent events.

    Like /generate, a prompt containing censored words is answered with the
    safe response instead of being sent to the model.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return EventSourceResponse(_safe_response_events())
    event_source = event_generator(
        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
    )
    return EventSourceResponse(event_source)


@app.post("/generate")
@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
    """Generate a complete response in one call.

    Both the prompt and the generated text are checked for censored words; if
    either matches, the processor's safe response is returned instead.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    with running_lock:
        # NOTE(review): sampling params are passed without do_sample; whether they
        # take effect depends on the model's generation config — confirm.
        output = model.generate(**inputs, **data.dict(exclude={"history"}))
    output = output.cpu()
    prompt_len = inputs["input_ids"].size(1)
    # Drop the echoed prompt tokens; decode only the newly generated part.
    response = output[0, prompt_len:]
    out_string = tokenizer.decode(response, skip_special_tokens=True)
    out_string = prompt_processor.postprocess_output(out_string)
    if prompt_processor.has_censored_words(out_string):
        return prompt_processor.SAFE_RESPONSE
    return out_string


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained",
        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
    )
    parser.add_argument(
        "--quant",
        choices=["8bit", "4bit"],
        default=None,
        help="Quantization mode. Default: None (no quantization, fp16).",
    )
    parser.add_argument(
        "--gptq_checkpoint",
        default=None,
        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
    )
    parser.add_argument(
        "--gptq_group_size",
        type=int,
        default=128,
        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
    )
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument(
        "--profanity_file",
        default=None,
        help="Path to profanity words list. It should be a JSON file containing a list of words.",
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped under -O.
    if args.quant == "4bit" and args.gptq_checkpoint is None:
        parser.error("Please specify a GPTQ checkpoint.")

    # tokenizer, prompt_processor and model are module globals read by the
    # endpoint functions above.
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

    if args.profanity_file is not None:
        censored_words = load_json(args.profanity_file)
    else:
        censored_words = []
    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

    if args.quant == "4bit":
        with low_resource_init():
            config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == "8bit"),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != "8bit":
            model.half()  # seems to fix bugs for some users.
        model.eval()

    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()