2021-12-27 07:04:32 +00:00
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
from ._utils import calc_acc
|
|
|
|
from .accuracy_2d import Accuracy2D
|
|
|
|
from .accuracy_2p5d import Accuracy2p5D
|
|
|
|
from .accuracy_3d import Accuracy3D
|
2021-12-29 15:32:10 +00:00
|
|
|
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
2021-12-27 07:04:32 +00:00
|
|
|
|
|
|
|
# Dispatch table: tensor-parallel mode string -> accuracy class used for that
# mode. Modes that are not listed here have no entry in this table.
_parallel_accuracy = {'2d': Accuracy2D, '2.5d': Accuracy2p5D, '3d': Accuracy3D}
|
|
|
|
|
|
|
|
|
|
|
|
class Accuracy(nn.Module):
    """Accuracy metric that dispatches on the current tensor-parallel mode.

    At construction time the mode returned by ``get_tensor_parallel_mode()``
    is looked up in ``_parallel_accuracy``:

    * if the mode has an entry, the mapped class is instantiated (with no
      arguments) and stored as ``self.acc``;
    * otherwise the plain ``calc_acc`` function is stored as ``self.acc``.

    ``forward`` then simply delegates to whatever was stored.
    """

    def __init__(self):
        """Select the accuracy implementation for the active parallel mode."""
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        # One dictionary lookup instead of a membership test followed by
        # indexing; ``None`` means the mode has no specialised class.
        parallel_cls = _parallel_accuracy.get(tensor_parallel)
        if parallel_cls is None:
            self.acc = calc_acc
        else:
            self.acc = parallel_cls()

    def forward(self, *args):
        """Forward all positional arguments to the selected implementation.

        Args:
            *args: passed through unchanged to ``self.acc``.
                NOTE(review): presumably ``(logits, targets)`` -- confirm
                against ``calc_acc`` and the parallel accuracy classes.

        Returns:
            Whatever ``self.acc(*args)`` returns.
        """
        return self.acc(*args)
|