mirror of https://github.com/hpcaitech/ColossalAI
from torch import nn

from colossalai.nn.layer.utils import get_tensor_parallel_mode

from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D

# Maps each tensor-parallel mode to the accuracy implementation that knows
# how activations are sharded in that mode.
_parallel_accuracy = {
    '2d': Accuracy2D,
    '2.5d': Accuracy2p5D,
    '3d': Accuracy3D,
}


class Accuracy(nn.Module):
    """Accuracy metric that dispatches to the implementation matching the
    current tensor-parallel mode, falling back to the plain ``calc_acc``
    helper when no parallel-aware implementation is registered."""

    def __init__(self):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel not in _parallel_accuracy:
            # No parallel-aware implementation for this mode: use the helper.
            self.acc = calc_acc
        else:
            # Instantiate the accuracy module for the active parallel mode.
            self.acc = _parallel_accuracy[tensor_parallel]()

    def forward(self, *args):
        # Forward all arguments unchanged to the selected backend.
        return self.acc(*args)
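
A minimal usage sketch, not part of the source file: in a run without tensor parallelism, get_tensor_parallel_mode() returns a mode absent from _parallel_accuracy, so Accuracy falls back to calc_acc. The (logits, targets) calling convention below is an assumption about calc_acc's signature, and the import path assumes this file is colossalai/nn/metric/__init__.py.

import torch
from colossalai.nn.metric import Accuracy  # assumed module path

metric = Accuracy()
logits = torch.randn(8, 10)           # 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))  # ground-truth class indices
correct = metric(logits, targets)     # forward() passes *args to the backend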