import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.utils import clip_grad_norm_fp32


class ColossalaiOptimizer(Optimizer):

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def defaults(self):
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.optim.load_state_dict(*args, **kwargs)

    def state_dict(self):
        return self.optim.state_dict()

    def backward(self, loss: Tensor):
        loss.backward()

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        torch.autograd.backward(tensors=tensor, grad_tensors=grad)

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if max_norm > 0.0:
            clip_grad_norm_fp32(model.parameters(), max_norm)