mirror of https://github.com/hpcaitech/ColossalAI
support linear accumulation fusion (#5199)
support linear accumulation fusion support linear accumulation fusion fixpull/5217/head
parent
64519eb830
commit
02d2328a04
|
@ -7,6 +7,12 @@ try:
|
|||
except:
|
||||
fused_mix_prec_layer_norm_cuda = None
|
||||
|
||||
try:
|
||||
import fused_weight_gradient_mlp_cuda
|
||||
_grad_accum_fusion_available = True
|
||||
except ImportError:
|
||||
_grad_accum_fusion_available = False
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
r"""Layernorm
|
||||
|
@ -141,7 +147,19 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
# all-reduce scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
|
@ -214,7 +232,19 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
# reduce-scatter scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
|
@ -249,7 +279,20 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
# calculate gradient
|
||||
if len(input_parallel.shape) > 2:
|
||||
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
||||
grad_weight = grad_output.t().matmul(input_parallel)
|
||||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(input_parallel)
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(input_parallel)
|
||||
# grad_weight = grad_output.t().matmul(input_parallel)
|
||||
# wait until reduce-scatter finished
|
||||
reducescatter_handle.wait()
|
||||
|
||||
|
@ -388,7 +431,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
# calculate gradient
|
||||
if len(input_parallel.shape) > 2:
|
||||
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
||||
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
||||
grad_weight = input_parallel.t().matmul(grad_output)
|
||||
# wait until reduce-scatter finished
|
||||
reducescatter_handle.wait()
|
||||
|
|
|
@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
|
|||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
|
||||
if self.seq_parallel:
|
||||
output = linear_reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim
|
||||
|
|
Loading…
Reference in New Issue