diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/nn/metric/accuracy_2d.py index 1137d1963..a86832973 100644 --- a/colossalai/nn/metric/accuracy_2d.py +++ b/colossalai/nn/metric/accuracy_2d.py @@ -8,6 +8,7 @@ from ._utils import calc_acc class Accuracy2D(nn.Module): """Accuracy for 2D parallelism """ + def __init__(self): super().__init__()