mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.7 KiB
97 lines
3.7 KiB
import torch
|
|
import torch.distributed as dist
|
|
from colossalai.core import global_context as gpc
|
|
|
|
# Optional compiled extension that provides the fused layer-norm kernels used by
# FusedLayerNormAffineFunction1D. When it cannot be imported, the name is bound
# to None so the rest of this module still loads. Only ImportError is caught
# (it also covers ModuleNotFoundError and failures to load the extension's
# shared library); a bare ``except:`` would also swallow KeyboardInterrupt and
# SystemExit raised during import.
try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    fused_mix_prec_layer_norm_cuda = None
|
|
|
|
|
|
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
    r"""Layer normalization with affine weight and bias, computed by the
    ``fused_mix_prec_layer_norm_cuda`` extension.

    Both passes delegate to the compiled kernels (``forward_affine`` and
    ``backward_affine``). This class only makes the tensors contiguous, stores
    what the backward kernel needs, and wires the kernels into autograd.

    Args:
        input: input tensor to normalize.
        weight: affine weight tensor.
        bias: affine bias tensor.
        normalized_shape: trailing shape to normalize over, for an input of size
            :math:`[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]`.
            It is passed to the fused kernel unchanged.
        eps: a value added to the denominator for numerical stability.

    Note:
        If the extension could not be imported, the module-level name
        ``fused_mix_prec_layer_norm_cuda`` is ``None`` and calling this
        function fails with ``AttributeError``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        # Non-tensor arguments are kept on ctx for the backward kernel call.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # The kernel is handed the tensors directly; make them contiguous first
        # (presumably required by the extension -- confirm against its sources).
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        # Keep the statistics returned by the kernel (mean, invvar) so the
        # backward kernel can reuse them.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape,
            weight_, bias_, ctx.eps)
        # One gradient slot per forward input; normalized_shape and eps are
        # non-tensor arguments and receive None.
        return grad_input, grad_weight, grad_bias, None, None
|
|
|
|
|
|
class LinearWithAsyncCommunication(torch.autograd.Function):
    """Linear layer (``y = x @ weight^T + bias``) whose backward pass can overlap
    an all-reduce of the input gradient with the weight-gradient computation.

    Forward arguments (passed positionally through ``apply``):
        input_: input tensor whose last dimension matches ``weight.shape[1]``;
            any number of leading dimensions is accepted.
        weight: weight tensor of shape ``(out_features, in_features)``.
        bias: tensor added to the output, or ``None`` for no bias.
        parallel_mode: key passed to ``gpc.get_group`` to select the process
            group used for the all-reduce.
        async_grad_allreduce: if true, the input gradient is all-reduced over
            that group asynchronously while the weight/bias gradients are
            computed; if false, no communication happens in backward.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.parallel_mode = parallel_mode
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        # Gradient w.r.t. the input keeps the original (possibly N-D) shape.
        grad_input = grad_output.matmul(weight)

        # Collapse all leading dimensions so the weight/bias gradients are plain
        # 2-D reductions. reshape (not view) also handles 2-D and 4-D+ inputs
        # and non-contiguous gradients; for contiguous tensors it is a view.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        total_input = input_.reshape(-1, input_.shape[-1])

        handle = None
        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce, overlapped with the matmul below.
            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        grad_weight = grad_output_2d.t().matmul(total_input)
        grad_bias = grad_output_2d.sum(dim=0) if use_bias else None

        # all_reduce returns None instead of a work handle when this rank is not
        # part of the group, so only wait on a real handle.
        if handle is not None:
            handle.wait()

        # One gradient per forward input; parallel_mode and async_grad_allreduce
        # are non-tensor arguments and receive None.
        return grad_input, grad_weight, grad_bias, None, None
|
|
|
|
|
|
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    """Run ``LinearWithAsyncCommunication`` on the given arguments.

    Functional entry point for the autograd function. All arguments are
    forwarded positionally, in order, to ``LinearWithAsyncCommunication.apply``,
    and its output is returned unchanged.
    """
    forwarded = (input_, weight, bias, parallel_mode, async_grad_allreduce)
    return LinearWithAsyncCommunication.apply(*forwarded)
|