diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index c4ef48634..bfe408065 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -767,12 +767,12 @@ class _ReduceForward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group): - return _reduce(input_, process_group) + def forward(ctx, input_, process_group, fp8_communication=False): + return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None class _ReduceBackward(torch.autograd.Function):