from typing import Any, Callable, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper


def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Assemble the logits warpers that are applied before sampling.

    A warper is only added when its setting would actually change something:
    ``None`` always disables a warper, as do a temperature of exactly 1.0,
    a ``top_k`` of 0 and a ``top_p`` of 1.0 or more.

    Args:
        top_k (Optional[int], optional): value handed to ``TopKLogitsWarper``. Defaults to None.
        top_p (Optional[float], optional): value handed to ``TopPLogitsWarper``. Defaults to None.
        temperature (Optional[float], optional): value handed to ``TemperatureLogitsWarper``. Defaults to None.

    Returns:
        LogitsProcessorList: the enabled warpers, ordered temperature, top-k, top-p.
    """
    warpers = LogitsProcessorList()

    rescale_logits = temperature is not None and temperature != 1.0
    if rescale_logits:
        warpers.append(TemperatureLogitsWarper(temperature))

    restrict_to_top_k = top_k is not None and top_k != 0
    if restrict_to_top_k:
        warpers.append(TopKLogitsWarper(top_k))

    restrict_to_nucleus = top_p is not None and top_p < 1.0
    if restrict_to_nucleus:
        warpers.append(TopPLogitsWarper(top_p))

    return warpers


def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider DP
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


def sample(model: nn.Module,
           input_ids: torch.Tensor,
           max_length: int,
           early_stopping: bool = False,
           eos_token_id: Optional[int] = None,
           pad_token_id: Optional[int] = None,
           top_k: Optional[int] = None,
           top_p: Optional[float] = None,
           temperature: Optional[float] = None,
           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
           **model_kwargs) -> torch.Tensor:
    """Extend ``input_ids`` one token at a time by multinomial sampling.

    Args:
        model (nn.Module): called as ``model(**model_inputs)``; its output must expose
            ``outputs['logits']`` indexable as ``[:, -1, :]``.
        input_ids (torch.Tensor): prompt token ids; dimension 0 is the batch and
            dimension 1 the sequence position.
        max_length (int): upper bound on the length of the returned sequences.
        early_stopping (bool, optional): stop as soon as every sequence has produced
            ``eos_token_id``. Defaults to False.
        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
        pad_token_id (Optional[int], optional): token appended to sequences that have
            already finished. Required when ``eos_token_id`` is given. Defaults to None.
        top_k (Optional[int], optional): forwarded to ``prepare_logits_processor``. Defaults to None.
        top_p (Optional[float], optional): forwarded to ``prepare_logits_processor``. Defaults to None.
        temperature (Optional[float], optional): forwarded to ``prepare_logits_processor``. Defaults to None.
        prepare_inputs_fn (Optional[Callable], optional): called as
            ``prepare_inputs_fn(input_ids, **model_kwargs)`` to build the model inputs.
            When None, the model receives ``{'input_ids': input_ids}``. Defaults to None.
        update_model_kwargs_fn (Optional[Callable], optional): called as
            ``update_model_kwargs_fn(outputs, model_kwargs)`` after every step; its return
            value replaces ``model_kwargs``. Defaults to None.
        **model_kwargs: extra state passed to ``prepare_inputs_fn``.

    Returns:
        torch.Tensor: ``input_ids`` followed by the sampled tokens. ``input_ids`` itself is
            returned unchanged when it is already ``max_length`` long or longer.

    Raises:
        ValueError: if ``eos_token_id`` is set without ``pad_token_id`` and at least one
            token would have to be generated.
    """
    if input_ids.size(1) >= max_length:
        return input_ids

    # Validate once, up front: the condition never changes between steps, and failing
    # here avoids paying for a forward pass before reporting the misconfiguration.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # One flag per batch row (same dtype/device as input_ids): 1 = still generating, 0 = finished.
    unfinished_sequences = input_ids.new_ones((input_ids.shape[0],))

    for _ in range(input_ids.size(1), max_length):
        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
            if prepare_inputs_fn is not None else {'input_ids': input_ids}
        outputs = model(**model_inputs)

        # only the logits of the last position decide the next token
        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample; the softmax is computed in float32 regardless of the logits dtype
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break

    return input_ids


def generate(model: nn.Module,
             input_ids: torch.Tensor,
             max_length: int,
             num_beams: int = 1,
             do_sample: bool = True,
             early_stopping: bool = False,
             eos_token_id: Optional[int] = None,
             pad_token_id: Optional[int] = None,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             temperature: Optional[float] = None,
             prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
             update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
             **model_kwargs) -> torch.Tensor:
    """Pick a decoding strategy and return input_ids followed by the generated tokens.

    Only single-beam sampling (``num_beams == 1`` with ``do_sample=True``) is implemented;
    it is delegated to ``sample``.

    Args:
        model (nn.Module): model used for generation.
        input_ids (torch.Tensor): input sequence.
        max_length (int): max length of the returned sequence.
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to sample. Must be exactly ``True`` or ``False``;
            other values are rejected. Defaults to True.
        early_stopping (bool, optional): if True, the returned sequence may be shorter than
            max_length because generation stops once eos has been found. Defaults to False.
        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
        pad_token_id (Optional[int], optional): pad token id. Defaults to None.
        top_k (Optional[int], optional): number of highest-probability vocabulary tokens kept
            by top-k filtering. Defaults to None.
        top_p (Optional[float], optional): if set to a float < 1, only the smallest set of most
            probable tokens whose probabilities add up to top_p or more is kept. Defaults to None.
        temperature (Optional[float], optional): value used to modulate the next-token
            probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): function that
            builds the model inputs from input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): function that
            updates model_kwargs from the model outputs. Defaults to None.
        **model_kwargs: forwarded unchanged to ``sample``.

    Returns:
        torch.Tensor: whatever ``sample`` returns.

    Raises:
        NotImplementedError: for greedy search (one beam, ``do_sample=False``) and
            beam search (several beams, ``do_sample=False``).
        ValueError: for any other combination of ``num_beams`` and ``do_sample``.
    """
    single_beam = num_beams == 1
    multi_beam = num_beams > 1
    sampling_on = do_sample is True
    sampling_off = do_sample is False

    if single_beam and sampling_off:
        # greedy search
        raise NotImplementedError
    if multi_beam and sampling_off:
        # beam search
        raise NotImplementedError
    if not (single_beam and sampling_on):
        raise ValueError("Unsupported generation mode")

    # sampling with a single beam
    sampling_options = dict(early_stopping=early_stopping,
                            eos_token_id=eos_token_id,
                            pad_token_id=pad_token_id,
                            top_k=top_k,
                            top_p=top_p,
                            temperature=temperature,
                            prepare_inputs_fn=prepare_inputs_fn,
                            update_model_kwargs_fn=update_model_kwargs_fn)
    return sample(model, input_ids, max_length, **sampling_options, **model_kwargs)


@torch.no_grad()
def generate_with_actor(actor_model: nn.Module,
                        input_ids: torch.Tensor,
                        return_action_mask: bool = True,
                        **kwargs
                        ) -> Tuple[torch.LongTensor, Optional[torch.LongTensor], Optional[torch.BoolTensor]]:
    """Generate token sequence with actor model. Refer to `generate` for more details.

    Args:
        actor_model (nn.Module): model passed to ``generate``.
        input_ids (torch.Tensor): prompt token ids; dimension 1 is the sequence position.
        return_action_mask (bool, optional): whether to compute the action mask. Defaults to True.
        **kwargs: forwarded to ``generate``. ``pad_token_id`` and ``eos_token_id`` are also
            read here to build the masks.

    Returns:
        A 3-tuple ``(sequences, attention_mask, action_mask)``:

        * ``sequences``: the output of ``generate``.
        * ``attention_mask``: 1 where ``sequences`` differs from ``pad_token_id`` and 0
          elsewhere, or None when no ``pad_token_id`` was given.
        * ``action_mask``: one boolean column per generated token, True up to and including
          the first ``eos_token_id`` of each row and False afterwards (all True when no
          ``eos_token_id`` was given), or None when ``return_action_mask`` is False.
    """
    # generate sequences
    sequences = generate(actor_model, input_ids, **kwargs)

    # calculate auxiliary tensors
    attention_mask = None
    pad_token_id = kwargs.get('pad_token_id', None)
    if pad_token_id is not None:
        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
    if not return_action_mask:
        return sequences, attention_mask, None

    # left padding may be applied to the prompt, so only the generated part is masked
    input_len = input_ids.size(1)
    generated = sequences[:, input_len:]
    # Building the mask directly over the generated slice (rather than slicing the last
    # `num_generated` columns off a full-length mask) keeps it correctly empty when
    # nothing was generated, where a `-0` slice start would select every column.
    action_mask = torch.ones_like(generated, dtype=torch.bool)
    eos_token_id = kwargs.get('eos_token_id', None)
    if eos_token_id is not None:
        # Running count of eos tokens seen up to and including each position.
        eos_seen = (generated == eos_token_id).cumsum(dim=-1)
        # A token is an action while no eos occurred strictly before it, so the eos
        # token itself stays included; the first generated token is always an action.
        action_mask[:, 1:] = eos_seen[:, :-1] == 0
    return sequences, attention_mask, action_mask