From 02636c5bef5cc9ae9c6b2c0e38e4e53e28e47060 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 19 Aug 2024 02:26:52 +0000 Subject: [PATCH] fix the merge --- colossalai/shardformer/layer/_operation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index c4ef48634..bfe408065 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -767,12 +767,12 @@ class _ReduceForward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group): - return _reduce(input_, process_group) + def forward(ctx, input_, process_group, fp8_communication=False): + return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None class _ReduceBackward(torch.autograd.Function):