support linear accumulation fusion (#5199)

support linear accumulation fusion

support linear accumulation fusion

fix
pull/5217/head
flybird11111 2023-12-29 18:22:42 +08:00 committed by GitHub
parent 64519eb830
commit 02d2328a04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 5 deletions

View File

@ -7,6 +7,12 @@ try:
except:
fused_mix_prec_layer_norm_cuda = None
try:
import fused_weight_gradient_mlp_cuda
_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
@ -141,7 +147,19 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
@ -214,7 +232,19 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
@ -249,7 +279,20 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
else:
grad_weight = grad_output.t().matmul(input_parallel)
# grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()

View File

@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim