# Mirror of https://github.com/hpcaitech/ColossalAI (28 lines, 686 B, Python)
from torch import nn
|
|
|
|
from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
|
|
|
|
from ._utils import calc_acc
|
|
from .accuracy_2d import Accuracy2D
|
|
from .accuracy_2p5d import Accuracy2p5D
|
|
from .accuracy_3d import Accuracy3D
|
|
|
|
# Maps a tensor-parallel mode name to the accuracy class used for that mode.
# Modes that are absent from this table are handled by the caller's fallback.
_parallel_accuracy = {
    "2d": Accuracy2D,
    "2.5d": Accuracy2p5D,
    "3d": Accuracy3D,
}
|
|
|
|
|
|
class Accuracy(nn.Module):
    """Accuracy module that dispatches on the configured tensor-parallel mode.

    At construction time the mode returned by ``get_tensor_parallel_mode()``
    is looked up in ``_parallel_accuracy``. When the mode has an entry, that
    class is instantiated and used; otherwise the plain ``calc_acc`` function
    is used. The chosen callable is stored in ``self.acc``.
    """

    def __init__(self):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        # Single lookup: None means "no parallel-specific implementation".
        parallel_cls = _parallel_accuracy.get(tensor_parallel)
        if parallel_cls is None:
            self.acc = calc_acc
        else:
            self.acc = parallel_cls()

    def forward(self, *args, **kwargs):
        """Delegate to the selected accuracy callable.

        All positional and keyword arguments are passed through unchanged to
        ``self.acc`` and its result is returned as-is.
        """
        return self.acc(*args, **kwargs)
|