import argparse
import asyncio
import os
from threading import Lock
from typing import Generator, List, Optional

import torch
import uvicorn
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn

# System context handed to ChatPromptProcessor when the server starts.
CONTEXT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request. "
    "Do not generate new instructions."
)
# Length limit handed to ChatPromptProcessor alongside CONTEXT.
MAX_LEN = 512
# Shared by both endpoints: wraps the streaming sampler (via LockedIterator)
# and guards the blocking model.generate() call.
running_lock = Lock()


class GenerationTaskReq(BaseModel):
    """Request body shared by the ``/generate`` and ``/generate/stream`` endpoints."""

    # Number of tokens to generate; must be in (0, 512].
    max_new_tokens: int = Field(gt=0, le=512, example=64)
    # Conversation so far; at least one Dialogue entry is required.
    history: List[Dialogue] = Field(min_items=1)
    # Optional sampling controls. Omitted values stay None and are forwarded as-is.
    # top_k must be > 0; top_p and temperature must lie strictly between 0 and 1.
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
    # Must be > 1.0 when given. Only the non-streaming endpoint forwards it;
    # the streaming path does not pass it to the sampler.
    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)


# Application object plus per-client rate limiting (slowapi).
app = FastAPI()
limiter = Limiter(key_func=get_remote_address)

# slowapi looks the limiter up on app.state; the handler turns a
# RateLimitExceeded raised by a limited endpoint into an HTTP response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# set CORS
origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)


def parse_cors_origins(origin_spec: Optional[str]) -> List[str]:
    """Turn a comma-separated CORS origin spec into a list of allowed origins.

    Args:
        origin_spec: Value of the ``CORS_ORIGIN`` environment variable, or
            ``None`` when the variable is not set.

    Returns:
        ``["*"]`` (allow every origin) when ``origin_spec`` is ``None``.
        Otherwise the individual origins with surrounding whitespace removed
        and empty entries dropped, so ``"a, b"`` and ``"a,b,"`` both yield
        ``["a", "b"]``. An empty spec yields ``[]``, i.e. no origin is allowed.
    """
    if origin_spec is None:
        # allow CORS from all origins
        return ["*"]
    # allow CORS from the specified origins; an origin with stray whitespace
    # would never match a browser's Origin header, so normalise it here
    stripped = (origin.strip() for origin in origin_spec.split(","))
    return [origin for origin in stripped if origin]


origins = parse_cors_origins(origin_spec_from_env)

# Install CORS handling: any method and any header is accepted from the
# origins computed above.
# NOTE(review): when CORS_ORIGIN is unset, origins is ["*"] while credentials
# are also allowed; the FastAPI CORS docs discourage that combination --
# confirm whether credentialed cross-origin requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=origins,
)


def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
    """Yield the model's continuation of ``prompt`` as decoded text chunks.

    Reads the module-level ``tokenizer``, ``model`` and ``running_lock``.

    Args:
        prompt: Fully assembled prompt text to tokenize and feed to the model.
        max_new_tokens: Forwarded to the sampler as ``max_generate_tokens``.
        top_k: Forwarded to the sampler unchanged (may be ``None``).
        top_p: Forwarded to the sampler unchanged (may be ``None``).
        temperature: Forwarded to the sampler unchanged (may be ``None``).

    Yields:
        One string per sampler step that produced at least one non-special
        token. The first chunk has leading whitespace stripped; later chunks
        get a leading space when their first sub-token starts with "▁".
    """
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        "max_generate_tokens": max_new_tokens,
        "early_stopping": True,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "prepare_inputs_fn": model.prepare_inputs_for_generation,
        "update_model_kwargs_fn": update_model_kwargs_fn,
    }
    # The special-token list does not change while streaming: read it once and
    # use a set so the per-token membership test below is not a list scan.
    special_tokens = frozenset(tokenizer.all_special_tokens)
    is_first_word = True
    # The sampler is wrapped together with running_lock; presumably this
    # serialises sampling steps with the non-streaming endpoint -- see
    # LockedIterator in utils to confirm.
    generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
    for output in generator:
        tokens = tokenizer.convert_ids_to_tokens(output.cpu(), skip_special_tokens=True)
        current_sub_tokens = [token for token in tokens if token not in special_tokens]
        if not current_sub_tokens:
            # Nothing but special tokens in this step: emit no chunk.
            continue
        out_string = tokenizer.sp_model.decode(current_sub_tokens)
        if is_first_word:
            out_string = out_string.lstrip()
            is_first_word = False
        elif current_sub_tokens[0].startswith("▁"):
            # whitespace will be ignored by the frontend
            # ("▁" marks a word start, so the separating space is put back explicitly)
            out_string = " " + out_string
        yield out_string


async def event_generator(request: Request, generator: Generator):
    """Adapt a blocking text generator into a stream of server-sent events.

    Args:
        request: Incoming request; polled before every step so that generation
            stops once the client has disconnected.
        generator: Synchronous iterator producing the text chunks to send.

    Yields:
        ``{"event": "generate", "data": <chunk>}`` for each chunk, followed by a
        single ``{"event": "end", "data": ""}`` once the generator is exhausted.
        Nothing more is yielded (not even "end") after a disconnect is seen.
    """
    loop = asyncio.get_running_loop()
    # StopIteration cannot be propagated through a Future, so exhaustion is
    # signalled with a private sentinel passed as next()'s default instead.
    exhausted = object()
    while True:
        if await request.is_disconnected():
            break
        # Each step is a blocking call; running it in the default executor
        # keeps the event loop free to serve other connections meanwhile.
        chunk = await loop.run_in_executor(None, next, generator, exhausted)
        if chunk is exhausted:
            yield {"event": "end", "data": ""}
            break
        yield {"event": "generate", "data": chunk}


@app.post("/generate/stream")
@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
    """Stream a chat completion for the given history as server-sent events.

    Limited to one call per second per rate-limit key.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    # repetition_penalty is not forwarded: the streaming sampler does not take it.
    text_chunks = generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
    return EventSourceResponse(event_generator(request, text_chunks))


@app.post("/generate")
@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
    """Return a complete chat completion for the given history in one response.

    Limited to one call per second per rate-limit key. The canned safe
    response is returned instead when either the prompt or the generated
    reply contains censored words.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE

    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {name: tensor.cuda() for name, tensor in encoded.items()}
    prompt_len = inputs["input_ids"].size(1)

    # Every request field except the history is forwarded to generate();
    # the model is only used while holding the shared lock.
    with running_lock:
        generated = model.generate(**inputs, **data.dict(exclude={"history"}))

    # Drop the echoed prompt and keep only the newly generated ids.
    reply_ids = generated.cpu()[0, prompt_len:]
    reply = prompt_processor.postprocess_output(tokenizer.decode(reply_ids, skip_special_tokens=True))
    return prompt_processor.SAFE_RESPONSE if prompt_processor.has_censored_words(reply) else reply


if __name__ == "__main__":
    # Script entry point: parse the CLI, load tokenizer and model, then serve.
    # tokenizer, prompt_processor and model are bound at module level on
    # purpose -- the request handlers above read them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained",
        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
    )
    parser.add_argument(
        "--quant",
        choices=["8bit", "4bit"],
        default=None,
        help="Quantization mode. Default: None (no quantization, fp16).",
    )
    parser.add_argument(
        "--gptq_checkpoint",
        default=None,
        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
    )
    parser.add_argument(
        "--gptq_group_size",
        type=int,
        default=128,
        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
    )
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument(
        "--profanity_file",
        default=None,
        help="Path to profanity words list. It should be a JSON file containing a list of words.",
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped under
    # `python -O`, and this way the user gets a usage message and exit code 2.
    if args.quant == "4bit" and args.gptq_checkpoint is None:
        parser.error("Please specify a GPTQ checkpoint (--gptq_checkpoint) when using --quant 4bit.")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

    # Without a profanity file nothing is censored.
    censored_words = load_json(args.profanity_file) if args.profanity_file is not None else []
    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

    if args.quant == "4bit":
        # Build the model skeleton under low_resource_init, then load the GPTQ
        # weights into it and move the result to the GPU.
        with low_resource_init():
            llama_config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(llama_config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == "8bit"),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != "8bit":
            model.half()  # seems to fix bugs for some users.
    # Switch to eval mode whichever way the model was loaded (the call is
    # idempotent, so it is harmless if a loader already did this).
    model.eval()

    server_config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=server_config)
    server.run()