from typing import Union

import torch.nn as nn
from torch import Tensor

from ..parallel_2d._operation import split_batch_2d
from ..parallel_2p5d._operation import split_batch_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode

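# Map each tensor parallel mode to its batch-splitting function.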
_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}


def partition_batch(input_: Union[Tensor, dict]) -> Union[Tensor, dict]:
    """Split ``input_`` along the batch dimension across tensor parallel ranks.

    If the current tensor parallel mode is '2d', '2.5d' or '3d', the tensor
    (or each tensor value of a dict) is split by the matching function;
    otherwise the input is returned unchanged.
    """
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, dict):
            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_
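
# A minimal usage sketch (hedged): it assumes the ColossalAI global context
# has been launched so that get_tensor_parallel_mode() returns a valid mode;
# the tensor names and shapes below are illustrative only.
#
#   import torch
#   batch = {'tokens': torch.randint(0, 1000, (16, 128)),
#            'labels': torch.randint(0, 1000, (16, 128))}
#   local = partition_batch(batch)  # each value split along dim 0 under
#                                   # '2d'/'2.5d'/'3d'; unchanged otherwise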


class ColossalaiModule(nn.Module):
    """Wrap an ``nn.Module`` so extra attributes can be attached to it.

    The wrapped module's instance attributes (including its parameters,
    buffers and submodules) are copied onto the wrapper, and its public
    methods are re-exposed as bound methods of the original module, so the
    wrapper behaves like the module it wraps. Keyword arguments are set as
    additional attributes on the wrapper.
    """

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        # copy instance attributes (parameters, buffers, submodules, ...)
        self.__dict__ = module.__dict__.copy()
        # copy methods defined on the module's class, except __init__ and
        # forward, as bound methods of the original module
        for name, attr in module.__class__.__dict__.items():
            if name not in ['__init__', 'forward'] and callable(attr):
                setattr(self, name, getattr(module, name))
        # keep a reference to the original forward so forward() can delegate
        self._forward_func = module.forward
        for k, v in kwargs.items():
            setattr(self, k, v)

    def forward(self, *args, **kwargs):
        return self._forward_func(*args, **kwargs)
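
# A minimal usage sketch (hedged): wraps a plain Linear layer; the attribute
# name `tensor_parallel_size` is illustrative, not part of the API.
#
#   import torch
#   linear = nn.Linear(128, 128)
#   wrapped = ColossalaiModule(linear, tensor_parallel_size=4)
#   out = wrapped(torch.randn(16, 128))  # identical to linear(...)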