ColossalAI/colossalai/nn/layer/parallel_2d/_operation.py

564 lines
21 KiB
Python

from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
def matmul_2d(a,
b,
summa_dim,
out_shape,
row_rank=None,
col_rank=None,
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
"""Matrix multiplication for 2D parallelism
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row, defaults to None
:type row_rank: int, optional
:param col_rank: the rank of column, defaults to None
:type col_rank: int, optional
:param row_parallel_mode: row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW
:type row_parallel_mode: str, optional
:param col_parallel_mode: column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL
:type col_parallel_mode: str, optional
:return: :math:`C = AB`
:rtype: torch.tensor
"""
if row_rank is None:
row_rank = gpc.get_local_rank(col_parallel_mode)
if col_rank is None:
col_rank = gpc.get_local_rank(row_parallel_mode)
data_parallel_rank = 0 if not gpc.is_initialized(
ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = summa_dim ** 2
return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode,
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size
)
class Matmul_AB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
# A: [b / q, s, h / q] -> [(b * s) / q, h / q]
# B: [h / q, s / q]
# C: [b / q, s, s / q] -> [(b * s) / q, s / q]
assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape
B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
op_a.wait()
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
for op in [op_a, op_b]:
op.wait()
for i in range(summa_dim):
src_a = i + summa_dim * row_rank
src_b = i + summa_dim * col_rank
src_a = src_a % summa_dim
src_b = src_b % summa_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ABT_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(summa_dim):
B_temp = B.clone()
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=gpc.get_group(col_parallel_mode))
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c,
group=gpc.get_group(row_parallel_mode))
if i == col_rank:
C = C_temp.clone()
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_AB_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(summa_dim):
A_temp = A.clone()
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=gpc.get_group(row_parallel_mode))
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c,
group=gpc.get_group(col_parallel_mode))
if i == row_rank:
C = C_temp.clone()
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(
output_size_per_partition,
dtype=bias.dtype,
device=get_current_device())
src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank,
group=gpc.get_group(col_parallel_mode))
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.bias = skip_bias_add
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
if skip_bias_add:
return bias_temp
else:
output = input + bias_temp
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size
tensor_parallel_size = ctx.tensor_parallel_size
if ctx.bias:
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(output_grad, dst=dst_rank,
group=gpc.get_group(col_parallel_mode))
if row_rank == 0:
return None, output_grad, None, None, None, None, None, None, None, None, None, None
else:
# for compatibility with zero optimizer, no grad should be None
grad_tmp = torch.zeros_like(output_grad)
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None
else:
reduce_dim = tuple(range(output_grad.ndim - 1))
reduce = torch.sum(output_grad, dim=reduce_dim)
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(reduce, dst=dst_rank,
group=gpc.get_group(col_parallel_mode))
if row_rank == 0:
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None
else:
# for compatibility with zero optimizer, no grad should be None
reduce_tmp = torch.zeros_like(reduce)
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.normalized_shape = hidden_size
output = input * Var_x
ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_sum, group=gpc.get_group(row_parallel_mode))
output_grad_sum /= ctx.normalized_shape
output_grad_mul_x_sum = torch.sum(
output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
output_grad_mul_x_sum /= ctx.normalized_shape
input_grad = output_grad.clone()
input_grad -= x * output_grad_mul_x_sum
input_grad -= output_grad_sum
input_grad *= Var_x
return input_grad, None, None, None, None, None
# class Sum_2D(torch.autograd.Function):
#
# @staticmethod
# def forward(ctx: Any,
# inputs: Tensor,
# dim: int,
# summa_dim: int,
# row_parallel_mode: ParallelMode,
# keepdim: bool = False) -> Tensor:
# # input: [b/q, s, h/q]
# empty_cache()
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
# return out
#
# @staticmethod
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
class AllGatherLast(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.summa_dim = summa_dim
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
last_dim = summa_dim * inputs.size(-1)
outputs_shape = (last_dim,) + inputs.shape[:-1]
outputs = torch.empty(
outputs_shape, dtype=inputs.dtype, device=get_current_device())
dist.all_gather(
list(outputs.chunk(summa_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
return grad.contiguous(), None, None
class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.summa_dim = summa_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
return grad, None, None