import argparse
import os
from threading import Lock
from typing import Generator, List, Optional

import torch
import uvicorn
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn

CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512
running_lock = Lock()


class GenerationTaskReq(BaseModel):
    max_new_tokens: int = Field(gt=0, le=512, example=64)
    history: List[Dialogue] = Field(min_items=1)
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)


limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# set CORS
origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)

if origin_spec_from_env is not None:
    # allow CORS from the specified origins
    origins = os.environ['CORS_ORIGIN'].split(',')
else:
    # allow CORS from all origins
    origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        'max_generate_tokens': max_new_tokens,
        'early_stopping': True,
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature,
        'prepare_inputs_fn': model.prepare_inputs_for_generation,
        'update_model_kwargs_fn': update_model_kwargs_fn,
    }
    is_first_word = True
    generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
    for output in generator:
        output = output.cpu()
        tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
        current_sub_tokens = []
        for token in tokens:
            if token in tokenizer.all_special_tokens:
                continue
            current_sub_tokens.append(token)
        if current_sub_tokens:
            out_string = tokenizer.sp_model.decode(current_sub_tokens)
            if is_first_word:
                out_string = out_string.lstrip()
                is_first_word = False
            elif current_sub_tokens[0].startswith('▁'):
                # whitespace will be ignored by the frontend
                out_string = ' ' + out_string
            yield out_string


async def event_generator(request: Request, generator: Generator):
    while True:
        if await request.is_disconnected():
            break
        try:
            yield {'event': 'generate', 'data': next(generator)}
        except StopIteration:
            yield {'event': 'end', 'data': ''}
            break


@app.post('/generate/stream')
@limiter.limit('1/second')
def generate(data: GenerationTaskReq, request: Request):
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    event_source = event_generator(
        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
    return EventSourceResponse(event_source)


@app.post('/generate')
@limiter.limit('1/second')
def generate_no_stream(data: GenerationTaskReq, request: Request):
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    with running_lock:
        output = model.generate(**inputs, **data.dict(exclude={'history'}))
    output = output.cpu()
    prompt_len = inputs['input_ids'].size(1)
    response = output[0, prompt_len:]
    out_string = tokenizer.decode(response, skip_special_tokens=True)
    out_string = prompt_processor.postprocess_output(out_string)
    if prompt_processor.has_censored_words(out_string):
        return prompt_processor.SAFE_RESPONSE
    return out_string


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'pretrained',
        help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
    parser.add_argument('--quant',
                        choices=['8bit', '4bit'],
                        default=None,
                        help='Quantization mode. Default: None (no quantization, fp16).')
    parser.add_argument(
        '--gptq_checkpoint',
        default=None,
        help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
    parser.add_argument('--gptq_group_size',
                        type=int,
                        default=128,
                        help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
    parser.add_argument('--http_host', default='0.0.0.0')
    parser.add_argument('--http_port', type=int, default=7070)
    parser.add_argument('--profanity_file',
                        default=None,
                        help='Path to profanity words list. It should be a JSON file containing a list of words.')
    args = parser.parse_args()

    if args.quant == '4bit':
        assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

    if args.profanity_file is not None:
        censored_words = load_json(args.profanity_file)
    else:
        censored_words = []
    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

    if args.quant == '4bit':
        with low_resource_init():
            config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == '8bit'),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != '8bit':
            model.half()    # seems to fix bugs for some users.
        model.eval()

    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()