#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Tuple

import torch
import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, scatter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import empty_cache, get_current_device
from torch import Tensor
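
# Note on the shape annotations below: they follow the 3D tensor-parallel
# layout, where q is the edge length of the processor cube (which matches the
# `depth` argument), so a tensor of logical shape [m, n, k] is held locally
# as, e.g., [m/q^2, n, k/q].
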
class Matmul_AB_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB`
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        # A: [m/q^2, n, k/q]
        # B: [k/q, h/q^2]
        # C: [m/q^2, n, h/q]
        empty_cache()
        ctx.save_for_backward(A, B)

        assert A.shape[-1] == B.shape[0], \
            'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)

        A_temp = all_gather(A, input_dim, input_parallel_mode)
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)

        C = torch.matmul(A_temp, B_temp)
        out = reduce_scatter(C, output_dim, output_parallel_mode)

        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim

        return out
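
    # For C = AB, the gradients are dA = dC @ B^T and dB = A^T @ dC, so the
    # backward pass reuses the transposed 3D matmuls with the parallel modes
    # and dims of A, B, and C permuted to match.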
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
                                         ctx.C_group_parallel_mode,
                                         ctx.B_group_parallel_mode,
                                         ctx.A_group_parallel_mode, ctx.C_dim,
                                         ctx.B_dim, ctx.A_dim)
            B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
                                         ctx.A_group_parallel_mode,
                                         ctx.C_group_parallel_mode,
                                         ctx.B_group_parallel_mode, ctx.A_dim,
                                         ctx.C_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None


class Matmul_ABT_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T`
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        # A: [m/q^2, n, h/q]
        # B: [k/q, h/q^2]
        # C: [m/q^2, n, k/q]
        empty_cache()
        ctx.save_for_backward(A, B)

        A_temp = all_gather(A, input_dim, input_parallel_mode)
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)

        C = torch.matmul(A_temp, B_temp.transpose(0, 1))
        out = reduce_scatter(C, output_dim, output_parallel_mode)

        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim

        return out
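
    # For C = AB^T, the gradients are dA = dC @ B and dB = dC^T @ A, which is
    # why backward dispatches to Matmul_AB_3D and Matmul_ATB_3D with the
    # parallel modes and dims permuted accordingly.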
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
                                        ctx.C_group_parallel_mode,
                                        ctx.B_group_parallel_mode,
                                        ctx.A_group_parallel_mode, ctx.C_dim,
                                        ctx.B_dim, ctx.A_dim)
            B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
                                         ctx.C_group_parallel_mode,
                                         ctx.A_group_parallel_mode,
                                         ctx.B_group_parallel_mode, ctx.C_dim,
                                         ctx.A_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None


class Matmul_ATB_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB`
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = 0,
                output_dim: int = -1) -> Tensor:
        # A: [m/q^2, n, k/q]
        # B: [m/q^2, n, h/q]
        # C: [k/q, h/q^2]
        empty_cache()
        ctx.save_for_backward(A, B)

        A_temp = all_gather(A, input_dim, input_parallel_mode)
        A_temp = A_temp.reshape(-1, A.shape[-1])
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)
        B_temp = B_temp.reshape(-1, B.shape[-1])

        C = torch.matmul(A_temp.transpose(0, 1), B_temp)
        out = reduce_scatter(C, output_dim, output_parallel_mode)

        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim

        return out
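
    # For C = A^T B, the gradients are dA = B @ dC^T and dB = A @ dC, hence
    # backward reuses Matmul_ABT_3D and Matmul_AB_3D with the roles of A, B,
    # and C permuted.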
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
                                         ctx.B_group_parallel_mode,
                                         ctx.C_group_parallel_mode,
                                         ctx.A_group_parallel_mode, ctx.B_dim,
                                         ctx.C_dim, ctx.A_dim)
            B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
                                        ctx.A_group_parallel_mode,
                                        ctx.C_group_parallel_mode,
                                        ctx.B_group_parallel_mode, ctx.A_dim,
                                        ctx.C_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None


class Add_3D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b`
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        # input: [m/q^2, n, h/q]
        # bias: [h/q^2]
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        bias_temp = bias.clone()
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=gpc.get_group(input_parallel_mode))
        # [h/q]
        bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)

        out = input_ + bias_temp

        ctx.depth = depth
        ctx.src_rank = src_rank
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode

        return out
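
    # The bias is stored as an [h/q^2] shard: forward broadcasts it within the
    # input group and all-gathers it over the weight group to recover the
    # [h/q] slice. Backward therefore sums the output gradient over all but
    # the last dim, reduce-scatters it, and reduces it back onto the rank that
    # owns the bias shard; other ranks return None for the bias gradient.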
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [m/q^2, n, h/q]
        with torch.no_grad():
            # [h/q]
            grad = torch.sum(output_grad,
                             dim=tuple(range(len(output_grad.shape))[:-1]))
            bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
            dist.reduce(bias_grad,
                        dst=ctx.src_rank,
                        group=gpc.get_group(ctx.A_group_parallel_mode))
            if gpc.get_local_rank(
                    ctx.A_group_parallel_mode) != gpc.get_local_rank(
                        ctx.C_group_parallel_mode):
                bias_grad = None
        return output_grad, bias_grad, None, None, None, None


class Mul_3D(torch.autograd.Function):
    """Elementwise multiplication by a bias vector: :math:`C = A * b`
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        # input: [m/q^2, n, h/q]
        # bias: [h/q^2]
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        # [h/q^2]
        bias_temp = bias.clone()
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=gpc.get_group(input_parallel_mode))
        # [h/q]
        bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)

        empty_cache()
        ctx.save_for_backward(input_, bias_temp)

        out = torch.mul(input_, bias_temp)

        ctx.depth = depth
        ctx.src_rank = src_rank
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode

        return out
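
    # With out = input * bias (bias broadcast over the leading dims), the
    # gradients are d_input = d_out * bias and d_bias = sum(d_out * input)
    # over all but the last dim; the bias gradient then follows the same
    # reduce-scatter / reduce path as in Add_3D.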
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [m/q^2, n, h/q]
        with torch.no_grad():
            input_, bias = ctx.saved_tensors
            # [m/q^2, n, h/q]
            input_grad = torch.mul(output_grad, bias)
            # [h/q]
            grad = torch.mul(output_grad, input_)
            grad = torch.sum(grad,
                             dim=tuple(range(len(output_grad.shape))[:-1]))
            bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
            dist.reduce(bias_grad,
                        dst=ctx.src_rank,
                        group=gpc.get_group(ctx.A_group_parallel_mode))
            if gpc.get_local_rank(
                    ctx.A_group_parallel_mode) != gpc.get_local_rank(
                        ctx.C_group_parallel_mode):
                bias_grad = None
        return input_grad, bias_grad, None, None, None, None


class Sum_3D(torch.autograd.Function):
    """Sum the input tensor along a dimension and all-reduce the result
    across the parallel group
    """
    @staticmethod
    def forward(ctx: Any,
                input_: Tensor,
                dim: int,
                depth: int,
                parallel_mode: ParallelMode,
                keepdim: bool = False) -> Tensor:
        # input: [m/q^2, n, h/q]
        out = torch.sum(input_, dim=dim, keepdim=keepdim)
        dist.all_reduce(out, group=gpc.get_group(parallel_mode))

        ctx.input_shape = input_.shape
        ctx.depth = depth
        ctx.group = parallel_mode
        ctx.dim = dim
        return out
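
    # Backward accumulates the upstream gradient across the parallel group,
    # restores the reduced dimension if keepdim was False, and repeats the
    # gradient along that dimension so it matches the saved input shape.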
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        with torch.no_grad():
            output_grad = output_grad.contiguous()
            dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
            if len(output_grad.shape) < len(ctx.input_shape):
                output_grad = torch.unsqueeze(output_grad, ctx.dim)
            dims = [1 for _ in range(len(output_grad.shape))]
            dims[ctx.dim] = ctx.input_shape[ctx.dim]
            input_grad = output_grad.repeat(tuple(dims))
        return input_grad, None, None, None, None, None


class Reduce_3D(torch.autograd.Function):
    """All-reduce the input tensor across the parallel group
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, depth: int,
                parallel_mode: ParallelMode) -> Tensor:
        dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
        return input_.clone()
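
    # An all-reduce sum contributes each local input to the output with unit
    # weight, so the local gradient is simply the upstream gradient passed
    # through unchanged.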
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        return output_grad, None, None


class Slice_3D(torch.autograd.Function):
    """Slice the input tensor along a dimension and keep the chunk owned by
    the local rank
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
                parallel_mode: ParallelMode) -> Tensor:
        rank = gpc.get_local_rank(parallel_mode)
        out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()

        ctx.depth = depth
        ctx.parallel_mode = parallel_mode
        ctx.dim = dim
        ctx.input_shape = input_.shape

        return out
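
    # Backward undoes the slice: all-gathering the chunk gradients along the
    # sliced dimension reconstructs a gradient with the original input shape.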
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        with torch.no_grad():
            input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
            # Tensor.reshape is not in-place; assign the result so the
            # gradient actually takes the saved input shape.
            input_grad = input_grad.reshape(ctx.input_shape)
        return input_grad, None, None, None