from typing import Optional


class TensorParallelEnv(object):
    """Singleton container for tensor-parallel settings.

    Every ``TensorParallelEnv(...)`` call returns the same object. Because
    ``__init__`` runs on each call and delegates to :meth:`load`, constructing
    the class again overwrites the stored settings with the given arguments
    (or with the defaults for any argument that is omitted).
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are consumed by ``__init__``/``load``.
        # They must not be forwarded to ``object.__new__``, which raises
        # ``TypeError`` on extra arguments when ``__new__`` is overridden.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, *args, **kwargs):
        self.load(*args, **kwargs)

    def load(self,
             mode: Optional[str] = None,
             vocab_parallel: bool = False,
             parallel_input_1d: bool = False,
             summa_dim: Optional[int] = None,
             tesseract_dim: Optional[int] = None,
             tesseract_dep: Optional[int] = None,
             depth_3d: Optional[int] = None,
             input_group_3d=None,
             weight_group_3d=None,
             output_group_3d=None):
        """Store the given settings on the instance, replacing all fields.

        Each argument is assigned unchanged to the attribute of the same
        name; omitted arguments reset their attribute to the default. The
        keyword names match the keys produced by :meth:`save`, so
        ``env.load(**env.save())`` restores a snapshot.
        """
        self.mode = mode
        self.vocab_parallel = vocab_parallel
        self.parallel_input_1d = parallel_input_1d
        self.summa_dim = summa_dim
        self.tesseract_dim = tesseract_dim
        self.tesseract_dep = tesseract_dep
        self.depth_3d = depth_3d
        self.input_group_3d = input_group_3d
        self.weight_group_3d = weight_group_3d
        self.output_group_3d = output_group_3d

    def save(self):
        """Return the current settings as a dict keyed by ``load`` argument name."""
        return dict(mode=self.mode,
                    vocab_parallel=self.vocab_parallel,
                    parallel_input_1d=self.parallel_input_1d,
                    summa_dim=self.summa_dim,
                    tesseract_dim=self.tesseract_dim,
                    tesseract_dep=self.tesseract_dep,
                    depth_3d=self.depth_3d,
                    input_group_3d=self.input_group_3d,
                    weight_group_3d=self.weight_group_3d,
                    output_group_3d=self.output_group_3d)


class MoeEnv:
    """MoE environment variables.

    Holds the data/model parallel sizes configured by :meth:`setup`, an
    auxiliary-loss accumulator, and a flag controlling whether CUDA is
    enabled.
    """

    def __init__(self):
        self.data_parallel_size = None
        self.model_parallel_size = None
        self.aux_loss = None
        self.enable_cuda = True

    def setup(self, moe_model_size):
        """Derive the MoE parallel sizes from the global context.

        Args:
            moe_model_size: MoE model parallel size; must evenly divide the
                global data parallel size.

        Raises:
            NotImplementedError: if the global context reports a tensor or
                pipeline parallel size greater than 1.
            AssertionError: if the global data parallel size is not a
                multiple of ``moe_model_size``.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import; confirm.
        from .core import global_context as gpc
        if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
            raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")

        assert gpc.data_parallel_size % moe_model_size == 0, \
            "The size of data parallel needs to be divided by moe model parallel size"

        self.data_parallel_size = gpc.data_parallel_size // moe_model_size
        self.model_parallel_size = moe_model_size

    def is_initialized(self):
        """Return True once :meth:`setup` has set ``model_parallel_size``."""
        return self.model_parallel_size is not None

    def set_cuda_false(self):
        """Set the ``enable_cuda`` flag to False."""
        self.enable_cuda = False

    def reset_loss(self):
        """Reset the auxiliary-loss accumulator to zero."""
        self.aux_loss = 0

    def add_loss(self, loss):
        """Add ``loss`` to the auxiliary-loss accumulator.

        The accumulator starts as ``None``; it is treated as zero here so
        that calling this before :meth:`reset_loss` does not fail with a
        ``TypeError``.
        """
        if self.aux_loss is None:
            self.aux_loss = 0
        self.aux_loss += loss

    def get_loss(self):
        """Return the accumulated loss (``None`` if nothing has been reset or added)."""
        return self.aux_loss


# Module-level instances, created once when this module is imported.
# Both start with the default (unconfigured) values set by their constructors.
tensor_parallel_env = TensorParallelEnv()

moe_env = MoeEnv()