From 02d2328a04bfe05e8519f32c659bde0f4506bcfc Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 29 Dec 2023 18:22:42 +0800
Subject: [PATCH] support linear accumulation fusion (#5199)

support linear accumulation fusion

support linear accumulation fusion

fix
---
 colossalai/shardformer/layer/_operation.py | 51 ++++++++++++++++++++--
 colossalai/shardformer/layer/linear.py     |  2 +-
 2 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 8fd92a2ed..4bca335c8 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -7,6 +7,12 @@ try:
 except:
     fused_mix_prec_layer_norm_cuda = None
 
+try:
+    import fused_weight_gradient_mlp_cuda
+    _grad_accum_fusion_available = True
+except ImportError:
+    _grad_accum_fusion_available = False
+
 
 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
     r"""Layernorm
@@ -141,7 +147,19 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         # all-reduce scheduled first and have GPU resources allocated
         _ = torch.empty(1, device=grad_output.device) + 1
 
-        grad_weight = grad_output.t().matmul(total_input)
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                grad_weight = None
+            elif grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.async_grad_allreduce:
@@ -214,7 +232,19 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         # reduce-scatter scheduled first and have GPU resources allocated
         _ = torch.empty(1, device=grad_output.device) + 1
 
-        grad_weight = grad_output.t().matmul(total_input)
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                grad_weight = None
+            elif grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.async_grad_reduce_scatter:
@@ -249,7 +279,20 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         # calculate gradient
         if len(input_parallel.shape) > 2:
             input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-        grad_weight = grad_output.t().matmul(input_parallel)
+
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
+                grad_weight = None
+            elif grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(input_parallel)
+        else:
+            grad_weight = grad_output.t().matmul(input_parallel)
+        # grad_weight = grad_output.t().matmul(input_parallel)
 
         # wait until reduce-scatter finished
         reducescatter_handle.wait()
@@ -388,7 +431,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
         # calculate gradient
         if len(input_parallel.shape) > 2:
-            input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
         grad_weight = input_parallel.t().matmul(grad_output)
         # wait until reduce-scatter finished
         reducescatter_handle.wait()
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 9e6386223..eeb0ef399 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
             handle.wait()
             output = torch.cat(output_parallel_list, dim=-1)
         else:
-            output_parallel = F.linear(input_, self.weight)
+            output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
         if self.seq_parallel:
             output = linear_reducescatter_forward_gather_backward(
                 output_parallel, self.process_group, self.seq_parallel_dim
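
Note for reviewers: the three backward-pass hunks above add the same dispatch pattern. Below is a minimal standalone sketch of that pattern for reference; the helper name compute_weight_grad is illustrative and not part of the patch, while fused_weight_gradient_mlp_cuda and its wgrad_gemm_accum_fp32/fp16 kernels are the optional Apex extension the patch imports. The fallback branch is plain PyTorch.

# Sketch of the weight-gradient dispatch introduced by this patch:
# use the fused wgrad-accumulation kernel when the extension is importable
# and the weight already has a .grad buffer; otherwise fall back to matmul.
import torch

try:
    import fused_weight_gradient_mlp_cuda  # optional Apex CUDA extension
    _grad_accum_fusion_available = True
except ImportError:
    _grad_accum_fusion_available = False


def compute_weight_grad(total_input: torch.Tensor, grad_output: torch.Tensor, weight: torch.Tensor):
    """Return grad_weight, or None when the fused kernel already accumulated into weight.grad in place."""
    if _grad_accum_fusion_available and weight.grad is not None:
        grad = weight.grad
        if grad.dtype == torch.float32:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
            return None
        if grad.dtype == torch.float16:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
            return None
    # Fallback: dense matmul; autograd accumulates the returned gradient into weight.grad afterwards.
    return grad_output.t().matmul(total_input)

Returning None in the fused branches matters: it tells autograd there is no separate gradient tensor to add, since the kernel has already fused the GEMM with the accumulation into weight.grad.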