From f7acfa1bd5aa648cf3cf0e00005265c9b37870dd Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Tue, 20 Aug 2024 05:07:58 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/_operation.py       | 7 +++++--
 colossalai/shardformer/layer/qkv_fused_linear.py | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index bfe408065..ed7b35233 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -767,11 +767,14 @@ class _ReduceForward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, fp8_communication=False):
+    def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
+        ctx.grad_scale = grad_scale
         return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return grad_output, None, None, None
 
 
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 561867993..f9a41a467 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -555,7 +555,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = torch.matmul(input_, self.weight)
-                output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = torch.matmul(input_, self.weight)
                 output = reducescatter_forward_gather_backward(
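
Note: below is a minimal, self-contained sketch of the autograd pattern this
patch introduces in _ReduceForward. The sketch is hypothetical: the
distributed all-reduce performed by _reduce is stubbed out with an identity
so the example runs in a single CPU process, and the class name
_ReduceForwardSketch is invented for illustration. It demonstrates the two
properties the patch depends on: backward scales the incoming gradient by
grad_scale when one is given, and backward must return exactly one gradient
slot per argument passed to apply (four here: input_, process_group,
grad_scale, fp8_communication).

    # Hypothetical stand-in for _ReduceForward; the all-reduce is replaced
    # by an identity so the autograd behavior can be checked on CPU.
    import torch


    class _ReduceForwardSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
            ctx.grad_scale = grad_scale
            return input_.clone()  # real code: _reduce(input_, process_group, ...)

        @staticmethod
        def backward(ctx, grad_output):
            if ctx.grad_scale is not None:
                grad_output = grad_output * ctx.grad_scale
            # One gradient slot per forward input: input_, process_group,
            # grad_scale, fp8_communication. Returning only three here would
            # raise "returned an incorrect number of gradients (expected 4, got 3)".
            return grad_output, None, None, None


    x = torch.ones(2, requires_grad=True)
    out = _ReduceForwardSketch.apply(x, None, 0.5, False)
    out.sum().backward()
    print(x.grad)  # tensor([0.5000, 0.5000]) -- upstream gradient scaled by 0.5

The qkv_fused_linear.py hunk follows from the same signature change: with
grad_scale now occupying the third positional slot, the old positional call
reduce_forward(output_parallel, self.process_group, self.fp8_communication)
would silently bind the fp8 flag to grad_scale (assuming reduce_forward
passes its arguments positionally through to _ReduceForward.apply), so the
flag is passed by keyword instead.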