mirror of https://github.com/hpcaitech/ColossalAI
support linear accumulation fusion (#5199)
parent 64519eb830
commit 02d2328a04

@@ -7,6 +7,12 @@ try:
 except:
     fused_mix_prec_layer_norm_cuda = None
 
+try:
+    import fused_weight_gradient_mlp_cuda
+
+    _grad_accum_fusion_available = True
+except ImportError:
+    _grad_accum_fusion_available = False
 
 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
     r"""Layernorm
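The guard added above follows the usual pattern for this kernel: `fused_weight_gradient_mlp_cuda` is, as far as I know, the optional CUDA extension that ships with NVIDIA Apex (built with `--cpp_ext --cuda_ext`), so it is frequently absent and the flag records whether the fused path can be used at all. A minimal, self-contained probe of that availability check (a sketch, not part of the commit):

# Sketch of the availability probe introduced above. Assumes only that the
# optional extension, when installed, is importable under this exact name.
try:
    import fused_weight_gradient_mlp_cuda  # optional CUDA extension (typically from NVIDIA Apex)

    _grad_accum_fusion_available = True
except ImportError:
    _grad_accum_fusion_available = False

if __name__ == "__main__":
    print(f"linear gradient accumulation fusion available: {_grad_accum_fusion_available}")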
@@ -141,7 +147,19 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
             # all-reduce scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
 
-        grad_weight = grad_output.t().matmul(total_input)
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                grad_weight = None
+            elif grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.async_grad_allreduce:
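In the branch above, the fused kernel is handed the existing `weight.grad` buffer and `grad_weight` is then returned as `None`, which only makes sense if the kernel accumulates the weight gradient into that buffer in place (otherwise autograd would add the gradient a second time). As an illustration only, a rough pure-PyTorch stand-in for what `wgrad_gemm_accum_fp32`/`_fp16` compute here, without the GEMM-plus-accumulate fusion the real kernel provides:

import torch

def wgrad_gemm_accum_reference(total_input, grad_output, grad):
    # Stand-in for the fused kernel's semantics as used above: accumulate
    # dL/dW = grad_output^T @ total_input into the existing gradient buffer
    # in place, so no separate grad_weight tensor has to be returned.
    grad.add_(grad_output.t().matmul(total_input).to(grad.dtype))

# Shape check with hypothetical sizes: weight is (out_features, in_features),
# activations and upstream gradients are flattened to (tokens, features).
total_input = torch.randn(8, 32)   # (tokens, in_features)
grad_output = torch.randn(8, 16)   # (tokens, out_features)
grad = torch.zeros(16, 32)         # pre-existing weight.grad buffer
wgrad_gemm_accum_reference(total_input, grad_output, grad)
assert torch.allclose(grad, grad_output.t().matmul(total_input))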
@@ -214,7 +232,19 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # reduce-scatter scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
 
-        grad_weight = grad_output.t().matmul(total_input)
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                grad_weight = None
+            elif grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
         if ctx.async_grad_reduce_scatter:
@@ -249,7 +279,20 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # calculate gradient
             if len(input_parallel.shape) > 2:
                 input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-            grad_weight = grad_output.t().matmul(input_parallel)
+
+            if _grad_accum_fusion_available and weight.grad is not None:
+                grad = weight.grad
+                if grad.dtype == torch.float32:
+                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
+                    grad_weight = None
+                elif grad.dtype == torch.float16:
+                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
+                    grad_weight = None
+                else:
+                    grad_weight = grad_output.t().matmul(input_parallel)
+            else:
+                grad_weight = grad_output.t().matmul(input_parallel)
+            # grad_weight = grad_output.t().matmul(input_parallel)
             # wait until reduce-scatter finished
             reducescatter_handle.wait()
 
@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
             handle.wait()
             output = torch.cat(output_parallel_list, dim=-1)
         else:
-            output_parallel = F.linear(input_, self.weight)
+            output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
             if self.seq_parallel:
                 output = linear_reducescatter_forward_gather_backward(
                     output_parallel, self.process_group, self.seq_parallel_dim
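This last hunk is in the `Linear1D_Row` forward: the non-split branch now calls `linear_with_async_comm(input_, self.weight, None, None, False)` instead of `F.linear`, which, judging by its name and arguments, routes the row-parallel matmul through the `LinearWithAsyncCommunication` autograd function patched above (bias, process group and async all-reduce disabled), so its backward can take the fused accumulation path. Note that the fused branch is guarded by `weight.grad is not None`, so it only applies once a gradient buffer exists. A small sketch of pre-allocating that buffer so even the first backward can accumulate in place; the module here is a hypothetical stand-in, not ColossalAI's layer:

import torch

# Hypothetical stand-in layer, only to show the guard's implication: the fused
# path above requires weight.grad to be a real buffer, so the very first
# backward (while .grad is still None) falls back to the plain matmul unless
# the buffer is pre-allocated, e.g. by the surrounding training framework.
layer = torch.nn.Linear(32, 16, bias=False)
layer.weight.grad = torch.zeros_like(layer.weight)  # pre-allocate so fusion can apply from step 1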