from torch import nn

from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D
from colossalai.nn.layer.utils import get_tensor_parallel_mode

# Dispatch table: tensor-parallel mode name -> accuracy class used for that mode.
# Modes without an entry here are handled by the caller's fallback.
_parallel_accuracy = {'2d': Accuracy2D, '2.5d': Accuracy2p5D, '3d': Accuracy3D}


class Accuracy(nn.Module):
    """Accuracy metric dispatched on the current tensor-parallel mode.

    On construction the mode returned by ``get_tensor_parallel_mode()`` is
    looked up in ``_parallel_accuracy``. If the mode has an entry there
    ('2d', '2.5d' or '3d'), the matching class is instantiated and used;
    for any other mode the plain ``calc_acc`` function is used instead.
    """

    def __init__(self):
        super().__init__()
        # Single lookup: a mode with no registered class yields None, which
        # selects the ``calc_acc`` fallback.
        parallel_cls = _parallel_accuracy.get(get_tensor_parallel_mode())
        self.acc = calc_acc if parallel_cls is None else parallel_cls()

    def forward(self, *args, **kwargs):
        """Delegate the accuracy computation to the selected implementation.

        Args:
            *args: Positional arguments passed through unchanged to
                ``self.acc``.
            **kwargs: Keyword arguments passed through unchanged to
                ``self.acc``, so callers may name the arguments.

        Returns:
            Whatever the selected implementation returns.
        """
        return self.acc(*args, **kwargs)