ColossalAI/colossalai/legacy/nn/metric/__init__.py

from torch import nn

from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode

from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D

# Map each supported tensor parallel mode to its specialized accuracy metric.
_parallel_accuracy = {
    "2d": Accuracy2D,
    "2.5d": Accuracy2p5D,
    "3d": Accuracy3D,
}


class Accuracy(nn.Module):
    """Accuracy metric that dispatches to the implementation matching the
    current tensor parallel mode."""

    def __init__(self):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel not in _parallel_accuracy:
            # No (or an unsupported) tensor parallel mode: use the plain helper.
            self.acc = calc_acc
        else:
            # Instantiate the mode-specific accuracy module (e.g. Accuracy2D).
            self.acc = _parallel_accuracy[tensor_parallel]()

    def forward(self, *args):
        # Delegate directly to whichever accuracy callable was selected.
        return self.acc(*args)
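
For context, a minimal usage sketch (not part of the file): it assumes ColossalAI is installed, a distributed context has been launched, and no 2d/2.5d/3d tensor parallel mode is configured, so the plain calc_acc path is taken. That calc_acc accepts (logits, targets) is an assumption, and the shapes below are purely illustrative.

import torch

from colossalai.legacy.nn.metric import Accuracy

metric = Accuracy()                   # assumed: no *D mode set, so calc_acc is used
logits = torch.randn(8, 10)           # hypothetical batch of 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))  # hypothetical ground-truth class indices
acc = metric(logits, targets)         # forward() passes *args to the selected callable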