from torch import nn

from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D
from colossalai.nn.layer.utils import get_tensor_parallel_mode

# Tensor-parallel mode name -> accuracy module class to instantiate for that mode.
# Any mode not listed here falls back to the plain ``calc_acc`` function.
_parallel_accuracy = {
    '2d': Accuracy2D,
    '2.5d': Accuracy2p5D,
    '3d': Accuracy3D,
}


class Accuracy(nn.Module):
    """Accuracy metric that dispatches on the tensor-parallel mode.

    At construction time the current mode is read via
    ``get_tensor_parallel_mode()``. If it is one of the keys of
    ``_parallel_accuracy`` ('2d', '2.5d', '3d'), an instance of the matching
    parallel accuracy class is used. Otherwise the non-parallel ``calc_acc``
    function is used.
    """

    def __init__(self):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        # The mode is sampled once here; changing it after construction does
        # not affect which implementation this instance uses.
        if tensor_parallel not in _parallel_accuracy:
            self.acc = calc_acc
        else:
            # Assigning an nn.Module instance registers it as a submodule.
            self.acc = _parallel_accuracy[tensor_parallel]()

    def forward(self, *args, **kwargs):
        """Compute accuracy by delegating to the selected implementation.

        All positional and keyword arguments are forwarded unchanged. Their
        expected form is defined by the selected implementation; presumably
        ``(logits, targets)``, but confirm this against ``calc_acc`` and the
        ``Accuracy*D`` classes.

        Returns:
            Whatever the selected implementation returns.
        """
        return self.acc(*args, **kwargs)