mirror of https://github.com/hpcaitech/ColossalAI
536 lines
20 KiB
Python
536 lines
20 KiB
Python
|
from typing import Any, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch import Tensor
|
||
|
|
||
|
from colossalai.context.parallel_mode import ParallelMode
|
||
|
from colossalai.core import global_context as gpc
|
||
|
from colossalai.utils import get_current_device, empty_cache
|
||
|
|
||
|
|
||
|
def get_parallel_group(parallel_mode: ParallelMode):
    """Look up the process group associated with ``parallel_mode``.

    Thin convenience wrapper that forwards to ``gpc.get_group``.

    :param parallel_mode: parallel mode whose group is requested
    :return: whatever ``gpc.get_group`` returns for that mode
    """
    group = gpc.get_group(parallel_mode)
    return group
|
||
|
|
||
|
|
||
|
def get_global_rank():
    """Return the global rank reported by ``gpc.get_global_rank``."""
    global_rank = gpc.get_global_rank()
    return global_rank
|
||
|
|
||
|
|
||
|
def get_parallel_rank(parallel_mode: ParallelMode):
    """Return this process's local rank inside ``parallel_mode``.

    Thin convenience wrapper that forwards to ``gpc.get_local_rank``.

    :param parallel_mode: parallel mode whose local rank is requested
    :return: whatever ``gpc.get_local_rank`` returns for that mode
    """
    local_rank = gpc.get_local_rank(parallel_mode)
    return local_rank
|
||
|
|
||
|
|
||
|
class Matmul_AB_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB`

    ``forward`` runs ``tesseract_dim`` rounds.  In every round a copy of the
    local ``A`` block goes through ``dist.broadcast`` on the
    ``row_parallel_mode`` group and a copy of the local ``B`` block goes
    through ``dist.broadcast`` on the ``col_parallel_mode`` group; the product
    of the two resulting blocks is accumulated into ``C``.

    ``backward`` delegates to the sibling kernels: ``Matmul_ABT_2p5D`` for the
    gradient of ``A`` and ``Matmul_ATB_2p5D`` for the gradient of ``B``.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        """Compute the local block of ``A @ B`` and reshape it to ``out_shape``.

        ``ctx`` may be ``None`` (the sibling kernels call this method directly
        from their ``backward``); in that case nothing is recorded for autograd.
        """
        # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
        # B: [h / dq, s / q]
        # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q]
        assert A.shape[-1] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)

        empty_cache()
        if ctx:
            ctx.save_for_backward(A, B)

        # Remember the caller-visible shapes, then flatten both operands to 2-D.
        A_shape, B_shape = A.shape, B.shape
        A = A.reshape((-1, A_shape[-1]))
        B = B.reshape((-1, B_shape[-1]))
        C = torch.zeros((A.shape[0], B.shape[-1]),
                        dtype=A.dtype,
                        device=get_current_device())

        # Part of the broadcast source rank that is the same in every round.
        rank_offset = dep_rank * (tesseract_dim ** 2) \
            + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
            + pipeline_parallel_rank * tensor_parallel_size

        for step in range(tesseract_dim):
            # Broadcast overwrites its buffer on receivers, so work on copies.
            A_temp = A.clone()
            B_temp = B.clone()

            src_a = step + row_rank * tesseract_dim + rank_offset
            dist.broadcast(A_temp, src=src_a,
                           group=get_parallel_group(row_parallel_mode))

            src_b = col_rank + step * tesseract_dim + rank_offset
            dist.broadcast(B_temp, src=src_b,
                           group=get_parallel_group(col_parallel_mode))

            # C += A_temp @ B_temp, accumulated in place.
            torch.addmm(C, A_temp, B_temp, out=C)

        out = C.reshape(out_shape)

        if ctx:
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Return gradients for ``A`` and ``B``; the 13 other inputs get ``None``."""
        A, B = ctx.saved_tensors

        # Positional arguments that follow ``out_shape`` in both sibling kernels.
        topology = (
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size,
        )

        # dA = dC @ B^T
        A_grad = Matmul_ABT_2p5D.forward(
            None, output_grad, B,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            *topology)
        # dB = A^T @ dC
        B_grad = Matmul_ATB_2p5D.forward(
            None, A, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            *topology)

        return (A_grad, B_grad) + (None,) * 13
|
||
|
|
||
|
|
||
|
class Matmul_ABT_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T`

    ``forward`` runs ``tesseract_dim`` rounds.  In every round a copy of the
    local ``B`` block goes through ``dist.broadcast`` on the
    ``col_parallel_mode`` group, the partial product ``A @ B_temp^T`` is formed
    and passed to ``dist.reduce`` on the ``row_parallel_mode`` group.  The
    partial product of the round whose index equals ``col_rank`` is kept as the
    result.

    ``backward`` delegates to the sibling kernels: ``Matmul_AB_2p5D`` for the
    gradient of ``A`` and ``Matmul_ATB_2p5D`` for the gradient of ``B``.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        """Compute the local block of ``A @ B^T`` and reshape it to ``out_shape``.

        ``ctx`` may be ``None`` (the sibling kernels call this method directly
        from their ``backward``); in that case nothing is recorded for autograd.
        """
        assert A.shape[-1] == B.shape[-1], \
            'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)

        empty_cache()
        if ctx:
            ctx.save_for_backward(A, B)

        # Remember the caller-visible shapes, then flatten both operands to 2-D.
        A_shape, B_shape = A.shape, B.shape
        A = A.reshape((-1, A_shape[-1]))
        B = B.reshape((-1, B_shape[-1]))
        # Placeholder; replaced by the reduced partial product of round ``col_rank``.
        C = torch.empty((A.shape[0], B.shape[0]),
                        dtype=A.dtype,
                        device=get_current_device())

        # Part of the source / destination rank that is the same in every round.
        rank_offset = dep_rank * (tesseract_dim ** 2) \
            + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
            + pipeline_parallel_rank * tensor_parallel_size

        for step in range(tesseract_dim):
            # Broadcast overwrites its buffer on receivers, so work on a copy.
            B_temp = B.clone()
            src_b = col_rank + step * tesseract_dim + rank_offset
            dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))

            C_temp = torch.matmul(A, B_temp.transpose(0, 1))

            dst_c = step + row_rank * tesseract_dim + rank_offset
            dist.reduce(C_temp, dst=dst_c, group=gpc.get_group(row_parallel_mode))

            if step == col_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Return gradients for ``A`` and ``B``; the 13 other inputs get ``None``."""
        A, B = ctx.saved_tensors

        # Positional arguments that follow ``out_shape`` in both sibling kernels.
        topology = (
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size,
        )

        # dA = dC @ B
        A_grad = Matmul_AB_2p5D.forward(
            None, output_grad, B,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            *topology)
        # dB = dC^T @ A
        B_grad = Matmul_ATB_2p5D.forward(
            None, output_grad, A,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            *topology)

        return (A_grad, B_grad) + (None,) * 13
|
||
|
|
||
|
|
||
|
class Matmul_ATB_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB`

    ``forward`` runs ``tesseract_dim`` rounds.  In every round a copy of the
    local ``A`` block goes through ``dist.broadcast`` on the
    ``row_parallel_mode`` group, the partial product ``A_temp^T @ B`` is formed
    and passed to ``dist.reduce`` on the ``col_parallel_mode`` group.  The
    partial product of the round whose index equals ``row_rank`` is kept as the
    result.

    ``backward`` delegates to the sibling kernels: ``Matmul_ABT_2p5D`` for the
    gradient of ``A`` and ``Matmul_AB_2p5D`` for the gradient of ``B``.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int):
        """Compute the local block of ``A^T @ B`` and reshape it to ``out_shape``.

        ``ctx`` may be ``None`` (the sibling kernels call this method directly
        from their ``backward``); in that case nothing is recorded for autograd.
        """
        assert A.shape[-2] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)

        empty_cache()
        if ctx:
            ctx.save_for_backward(A, B)

        # Remember the caller-visible shapes, then flatten both operands to 2-D.
        A_shape, B_shape = A.shape, B.shape
        A = A.reshape((-1, A_shape[-1]))
        B = B.reshape((-1, B_shape[-1]))
        # Placeholder; replaced by the reduced partial product of round ``row_rank``.
        C = torch.empty((A.shape[-1], B.shape[-1]),
                        dtype=A.dtype,
                        device=get_current_device())

        # Part of the source / destination rank that is the same in every round.
        rank_offset = dep_rank * (tesseract_dim ** 2) \
            + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
            + pipeline_parallel_rank * tensor_parallel_size

        for step in range(tesseract_dim):
            # Broadcast overwrites its buffer on receivers, so work on a copy.
            A_temp = A.clone()
            src_a = step + row_rank * tesseract_dim + rank_offset
            dist.broadcast(A_temp, src=src_a,
                           group=get_parallel_group(row_parallel_mode))

            C_temp = torch.matmul(A_temp.transpose(0, 1), B)

            dst_c = col_rank + step * tesseract_dim + rank_offset
            dist.reduce(C_temp, dst=dst_c,
                        group=get_parallel_group(col_parallel_mode))

            if step == row_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Return gradients for ``A`` and ``B``; the 13 other inputs get ``None``."""
        A, B = ctx.saved_tensors

        # Positional arguments that follow ``out_shape`` in both sibling kernels.
        topology = (
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size,
        )

        # dA = B @ dC^T
        A_grad = Matmul_ABT_2p5D.forward(
            None, B, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            *topology)
        # dB = A @ dC
        B_grad = Matmul_AB_2p5D.forward(
            None, A, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            *topology)

        return (A_grad, B_grad) + (None,) * 13
|
||
|
|
||
|
|
||
|
class Add_Bias_2p5D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b`

    ``forward`` takes the bias from the process with ``row_rank == 0`` (other
    processes start from a zero buffer), passes it through ``dist.broadcast``
    on the ``col_parallel_mode`` group and either returns it as is
    (``skip_bias_add=True``) or returns ``input + bias``.

    ``backward`` reduces the bias gradient over the same group onto the rank
    that was the broadcast source; processes with ``row_rank != 0`` return a
    zero bias gradient.
    """

    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                bias: Tensor,
                output_size_per_partition: int,
                tesseract_dim: int,
                tesseract_dep: int,
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                skip_bias_add: bool,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        """Broadcast the bias and optionally add it to ``input``.

        :return: the broadcast bias when ``skip_bias_add`` is true, otherwise
            ``input + bias``
        """
        if row_rank == 0:
            bias_temp = bias.clone()
        else:
            # Receive buffer for the broadcast below.
            bias_temp = torch.zeros(
                output_size_per_partition,
                dtype=bias.dtype,
                device=get_current_device())

        src_rank = col_rank + dep_rank * (tesseract_dim ** 2) \
            + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
            + pipeline_parallel_rank * tensor_parallel_size
        dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))

        ctx.row_rank = row_rank
        ctx.col_rank = col_rank
        ctx.dep_rank = dep_rank
        ctx.tesseract_dim = tesseract_dim
        ctx.tesseract_dep = tesseract_dep
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        ctx.dep_parallel_mode = dep_parallel_mode
        # NOTE: despite its name this attribute stores the ``skip_bias_add`` flag.
        ctx.bias = skip_bias_add
        ctx.data_parallel_rank = data_parallel_rank
        ctx.pipeline_parallel_rank = pipeline_parallel_rank
        ctx.pipeline_parallel_size = pipeline_parallel_size
        ctx.tensor_parallel_size = tensor_parallel_size

        if skip_bias_add:
            return bias_temp
        return input + bias_temp

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Any, ...]:
        """Return ``(input_grad, bias_grad, None * 14)`` - one entry per forward input.

        With ``skip_bias_add`` the forward output does not depend on ``input``,
        so ``input_grad`` is ``None`` and ``output_grad`` is the bias gradient.
        Otherwise ``output_grad`` is passed through for ``input`` and summed
        over all leading dimensions for the bias.
        """
        # Same rank formula as ``src_rank`` in forward.
        dst_rank = ctx.col_rank + ctx.dep_rank * (ctx.tesseract_dim ** 2) \
            + ctx.data_parallel_rank * ctx.pipeline_parallel_size * ctx.tensor_parallel_size \
            + ctx.pipeline_parallel_rank * ctx.tensor_parallel_size
        col_group = get_parallel_group(ctx.col_parallel_mode)

        if ctx.bias:
            input_grad = None
            # ``dist.reduce`` writes into its argument; reduce a copy so the
            # gradient buffer handed in by autograd is left untouched.
            bias_grad = output_grad.clone()
        else:
            input_grad = output_grad
            leading_dims = tuple(range(output_grad.ndim - 1))
            if leading_dims:
                bias_grad = torch.sum(output_grad, dim=leading_dims)
            else:
                # 1-D gradient: nothing to sum.  An empty ``dim`` tuple would
                # make ``torch.sum`` reduce over every dimension instead.
                bias_grad = output_grad.clone()

        dist.reduce(bias_grad, dst=dst_rank, group=col_group)

        if ctx.row_rank != 0:
            # Only the row-0 process owns the bias; the others contribute zeros.
            bias_grad = torch.zeros_like(bias_grad)

        # forward takes 16 inputs -> exactly 16 gradients in both modes.
        return (input_grad, bias_grad) + (None,) * 14
|
||
|
|
||
|
|
||
|
class _LayerNorm_2p5D(torch.autograd.Function):
    """Normalisation step ``(input - E_x) * Var_x`` with a hand-written backward.

    ``Var_x`` is expected to already hold ``1 / sqrt(Var[x] + eps)`` (see the
    inline comments).  Only ``input`` receives a gradient; ``backward``
    all-reduces its two row statistics over the ``row_parallel_mode`` group.
    """

    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                E_x: Tensor,
                Var_x: Tensor,
                hidden_size: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode) -> Tensor:
        """Return ``(input - E_x) * Var_x`` and stash what backward needs."""
        centered = input - E_x
        # in here, centered = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
        output = centered * Var_x

        ctx.hidden_size = hidden_size
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        ctx.dep_parallel_mode = dep_parallel_mode
        ctx.save_for_backward(output, Var_x)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return the gradient for ``input``; the other six inputs get ``None``."""
        row_group_mode = ctx.row_parallel_mode
        # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
        x, Var_x = ctx.saved_tensors

        with torch.no_grad():
            # Mean of the incoming gradient over the full hidden dimension.
            grad_mean = output_grad.sum(dim=-1, keepdim=True)
            dist.all_reduce(grad_mean, group=get_parallel_group(row_group_mode))
            grad_mean.div_(ctx.hidden_size)

            # Mean of (gradient * normalised input) over the full hidden dimension.
            grad_x_mean = (output_grad * x).sum(dim=-1, keepdim=True)
            dist.all_reduce(grad_x_mean, group=get_parallel_group(row_group_mode))
            grad_x_mean.div_(ctx.hidden_size)

            input_grad = output_grad.clone()
            input_grad.sub_(x * grad_x_mean)
            input_grad.sub_(grad_mean)
            input_grad.mul_(Var_x)

        return (input_grad,) + (None,) * 6
|
||
|
|
||
|
|
||
|
class Sum_2p5D(torch.autograd.Function):
    """Compute the sum of input tensors

    ``forward`` sums ``inputs`` over ``dim`` locally and then all-reduces the
    result over the ``row_parallel_mode`` group.

    ``backward`` broadcasts ``output_grad`` back to the shape of ``inputs``.
    The all-reduce is not differentiated: every process uses only its own
    ``output_grad``.
    """

    @staticmethod
    def forward(ctx,
                inputs,
                dim,
                tesseract_dim,
                row_parallel_mode,
                keepdim=False):
        """Return ``torch.sum(inputs, dim, keepdim)`` all-reduced over the row group.

        :param inputs: tensor to reduce (e.g. ``[b/q, s, h/q]``)
        :param dim: dimension(s) forwarded to ``torch.sum``
        :param tesseract_dim: unused here; kept for interface compatibility
        :param row_parallel_mode: parallel mode whose group is all-reduced over
        :param keepdim: forwarded to ``torch.sum``
        """
        # input: [b/q, s, h/q]
        empty_cache()
        # backward only needs the input's shape and how it was reduced, so the
        # tensor itself is not kept alive.
        ctx.input_shape = inputs.shape
        ctx.dim = dim
        ctx.keepdim = keepdim
        # sum: [b/q, s]
        out = torch.sum(inputs, dim=dim, keepdim=keepdim)
        torch.distributed.all_reduce(
            out, group=gpc.get_group(row_parallel_mode))
        return out

    @staticmethod
    def backward(ctx, output_grad):
        """Return the gradient for ``inputs``; the four other inputs get ``None``."""
        with torch.no_grad():
            input_shape = ctx.input_shape
            ndim = len(input_shape)
            grad = output_grad
            if not ctx.keepdim and ctx.dim is not None and ndim > 0:
                dims = ctx.dim if isinstance(ctx.dim, (tuple, list)) else (ctx.dim,)
                # Re-insert the reduced dimensions (ascending order keeps the
                # positions valid) so the gradient broadcasts against the input.
                for d in sorted(d % ndim for d in dims):
                    grad = grad.unsqueeze(d)
            # d(sum)/d(input) is 1 everywhere, so the input gradient is simply
            # output_grad repeated along the reduced dimensions.
            input_grad = grad.expand(input_shape).contiguous()
        return input_grad, None, None, None, None
|
||
|
|
||
|
|
||
|
class _ViT_Split_2p5D(torch.autograd.Function):
    """Keep one batch chunk per process in forward; all-gather grads in backward.

    ``forward`` splits ``inputs`` along dim 0 into
    ``tesseract_dep * tesseract_dim`` chunks and returns a copy of the chunk
    selected by this process's local rank in ``xz_parallel_mode``.
    ``backward`` gathers every process's ``output_grad`` into a tensor whose
    leading dimension is ``batch_size``.
    """

    @staticmethod
    def forward(ctx, inputs, batch_size,
                tesseract_dim, tesseract_dep,
                xz_parallel_mode):
        """Return this process's chunk of ``inputs`` along dim 0."""
        # inputs: [b, s, h/q]
        # output: [b/dq, s, h/q]
        empty_cache()

        ctx.batch_size = batch_size
        ctx.tesseract_dim = tesseract_dim
        ctx.tesseract_dep = tesseract_dep
        ctx.xz_parallel_mode = xz_parallel_mode

        local_rank = gpc.get_local_rank(xz_parallel_mode)
        pieces = torch.chunk(inputs, tesseract_dep * tesseract_dim, dim=0)
        # ``chunk`` yields views of ``inputs``; return an independent copy.
        return pieces[local_rank].clone()

    @staticmethod
    def backward(ctx, output_grad):
        """Return the gathered gradient for ``inputs``; the other inputs get ``None``."""
        # output_grad: [b/dq, s, h/q]
        # grads: [b, s, h/q]
        full_shape = (ctx.batch_size,) + output_grad.shape[1:]
        grads = torch.empty(full_shape,
                            dtype=output_grad.dtype,
                            device=get_current_device())
        # The chunks are views of ``grads``, so the gather fills it in place.
        slots = list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0))
        dist.all_gather(slots,
                        output_grad.contiguous(),
                        group=get_parallel_group(ctx.xz_parallel_mode))
        return grads, None, None, None, None
|