from contextlib import nullcontext
from typing import Optional

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss, ValueLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
from coati.trainer.utils import all_reduce_mean
from coati.models import Critic, disable_dropout
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.optimizer import HybridAdam


@ray.remote
class PPOConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
        num_generations=1,
        gamma:float=1.0,
        lam:float=0.95,
        kl_coef:float=0.05,
        use_wandb=True,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        self.gamma = gamma
        self.lam = lam
        self.kl_coef = kl_coef
        
        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.critic_model = Critic(path, **model_config)
        self.critic_model.model.gradient_checkpointing_enable()
        self.critic_model.train()
        
        # Disable dropout
        disable_dropout(self.policy_model)
        disable_dropout(self.critic_model)
        
        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
        self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
        
        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_reward = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
        self.accum_advantage = torch.zeros(1, device=self.device)
        self.accum_critic_loss = torch.zeros(1, device=self.device)
        self.accum_count = 0

        # Reference model is initialized from policy model.
        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations

        # Initialize verifiable reward.
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        self.reward_model = VerifiableReward(
            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
        )

        self.policy_loss_fn = PolicyLoss()
        self.critic_loss_fn = ValueLoss()
        self.global_step = 0
        if use_wandb and self.rank == 0:
            self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)

    def setup(self):
        super().setup()
        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
        self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
            self.critic_model, self.critic_optimizer
        )
        self.reference_model, *_ = self.booster.boost(self.reference_model)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
        """

        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
        action_mask = data["action_mask"]
        num_action = action_mask.shape[1]
        old_action_log_probs = data["action_log_probs"].detach()

        need_update = (step_idx + 1) % self.num_microbatches == 0

        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
        with ctx:
            policy_model_logits = self.policy_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )["logits"]
            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)

            with torch.no_grad():
                reference_model_logits = self.reference_model(
                    input_ids=data["input_ids"],
                    attention_mask=data["attention_mask"],
                )["logits"]
            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)

            value = self.critic_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )
            value = value[:, -num_action -1: -1] * action_mask

            r = self.reward_model(
                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
            )
            reward, kl = compute_reward_ppo(
                r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
            )
                
            # Calculate advantages
            # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46lastgaelam = 0
            lastgaelam = 0
            advantage_reversed = []
            for t in reversed(range(num_action)):
                nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
                delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
                lastgaelam = delta + self.gamma * self.lam * lastgaelam
                advantage_reversed.append(lastgaelam)
            advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
            advantage = advantage.detach()
            
            # KL divergence for logging
            per_token_kl = (
                torch.exp(reference_action_log_probs - action_log_probs)
                - (reference_action_log_probs - action_log_probs)
                - 1
            )
            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

            # Calculate Loss
            loss, skip_update, _ = self.policy_loss_fn(
                action_log_probs,
                old_action_log_probs,
                advantage,
                0,  # kl is already included in the advantage
                action_mask,
            )
            
            # Critic Loss
            # Hack: use the current value to approximate the old value, should be old value mathematically
            critic_loss = self.critic_loss_fn(
                value,
                value.detach(),
                advantage,
                action_mask=action_mask,
            )

            if not skip_update:
                self.booster.backward(loss, self.optimizer)
                self.critic_booster.backward(critic_loss, self.critic_optimizer)
            loss = all_reduce_mean(loss, self.plugin)
            critic_loss = all_reduce_mean(critic_loss, self.plugin)
            r_mean = all_reduce_mean(r.mean(), self.plugin)
            kl = all_reduce_mean(kl.mean(), self.plugin)
            advantage = all_reduce_mean(advantage.mean(), self.plugin)
            self.accum_loss.add_(loss.data)
            self.accum_critic_loss.add_(critic_loss.data)
            self.accum_advantage.add_(advantage.data)
            self.accum_reward.add_(r_mean.data)
            self.accum_kl.add_(kl.data)
            self.accum_count += 1
        if self.rank == 0:
            print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}")
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.critic_optimizer.step()
            self.critic_optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            if self.rank == 0:
                print(
                    "Loss:",
                    self.accum_loss.item() / self.accum_count,
                    "Reward:",
                    self.accum_reward.item() / self.accum_count,
                    "KL:",
                    self.accum_kl.item() / self.accum_count,
                )
                if self.global_step % 3 == 0:
                    for i in range(min(3, data["input_ids"].shape[0])):
                        response_decoded_for_logging = self.tokenizer.decode(
                            data["input_ids"][i], skip_special_tokens=True
                        )
                        response_reward_for_logging = r[i]
                        print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
                self.wandb_run.log(
                    {
                        "train/loss": self.accum_loss.item() / self.accum_count,
                        "train/reward": self.accum_reward.item() / self.accum_count,
                        "train/kl": self.accum_kl.item() / self.accum_count,
                        "train/critic_loss": self.accum_critic_loss.item() / self.accum_count,
                        "train/advantage": self.accum_advantage.item() / self.accum_count,
                    }
                )
            self.accum_loss.zero_()
            self.accum_reward.zero_()
            self.accum_kl.zero_()
            self.accum_advantage.zero_()
            self.accum_critic_loss.zero_()
            self.accum_count = 0
            self.global_step += 1
            return loss_scalar

    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict