mirror of https://github.com/hpcaitech/ColossalAI
564 lines
22 KiB
Python
564 lines
22 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from typing import Any, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from colossalai.communication import all_gather, all_reduce, reduce_scatter
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
from torch import Tensor
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
|
|
|
class linear_3d(torch.autograd.Function):
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any,
|
|
input_: Tensor,
|
|
weight: Tensor,
|
|
bias: Optional[Tensor],
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode,
|
|
input_dim: int = 0,
|
|
weight_dim: int = -1,
|
|
output_dim: int = 0) -> Tensor:
|
|
assert input_.shape[-1] == weight.shape[0], \
|
|
'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
|
|
|
|
ctx.use_bias = bias is not None
|
|
|
|
input_ = all_gather(input_, input_dim, input_parallel_mode)
|
|
input_ = torch.cat(input_, dim=input_dim)
|
|
# weight = all_gather(weight, weight_dim, weight_parallel_mode)
|
|
ctx.save_for_backward(input_, weight)
|
|
|
|
output = torch.matmul(input_, weight)
|
|
output = reduce_scatter(output, output_dim, output_parallel_mode)
|
|
|
|
if bias is not None:
|
|
# ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
|
|
# src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
|
|
# dist.broadcast(bias,
|
|
# src=src_rank,
|
|
# group=gpc.get_group(output_parallel_mode))
|
|
# bias = all_gather(bias, -1, weight_parallel_mode)
|
|
output += bias
|
|
# ctx.src_rank = src_rank
|
|
|
|
# ctx.save_for_backward(input_, weight)
|
|
# output = torch.matmul(input_, weight)
|
|
# dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
|
|
# output += bias
|
|
|
|
ctx.input_parallel_mode = input_parallel_mode
|
|
ctx.weight_parallel_mode = weight_parallel_mode
|
|
ctx.output_parallel_mode = output_parallel_mode
|
|
ctx.input_dim = input_dim
|
|
ctx.weight_dim = weight_dim
|
|
ctx.output_dim = output_dim
|
|
return output
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
input_, weight = ctx.saved_tensors
|
|
with torch.no_grad():
|
|
# input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
|
# dist.all_reduce(input_grad,
|
|
# group=gpc.get_group(ctx.input_parallel_mode))
|
|
# weight_grad = torch.matmul(
|
|
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
|
# output_grad.reshape(-1, output_grad.shape[-1]))
|
|
# dist.all_reduce(weight_grad,
|
|
# group=gpc.get_group(ctx.weight_parallel_mode))
|
|
|
|
# bias_grad = torch.sum(output_grad,
|
|
# dim=tuple(
|
|
# range(len(output_grad.shape))[:-1]))
|
|
# bias_grad = reduce_scatter(bias_grad, -1,
|
|
# ctx.weight_parallel_mode)
|
|
# dist.reduce(bias_grad,
|
|
# dst=ctx.src_rank,
|
|
# group=gpc.get_group(ctx.output_parallel_mode))
|
|
# if gpc.get_local_rank(
|
|
# ctx.output_parallel_mode) != gpc.get_local_rank(
|
|
# ctx.input_parallel_mode):
|
|
# bias_grad = None
|
|
|
|
# input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
|
|
# weight = all_gather(weight, ctx.weight_dim,
|
|
# ctx.weight_parallel_mode)
|
|
|
|
output_grad = all_gather(output_grad, ctx.output_dim,
|
|
ctx.output_parallel_mode)
|
|
output_grad = torch.cat(output_grad, dim=ctx.output_dim)
|
|
|
|
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
|
|
|
input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
|
|
ctx.input_parallel_mode,
|
|
async_op=True)
|
|
weight_grad = torch.matmul(
|
|
input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
|
output_grad.reshape(-1, output_grad.shape[-1]))
|
|
|
|
# weight_grad = torch.matmul(
|
|
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
|
# output_grad.reshape(-1, output_grad.shape[-1]))
|
|
# weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
|
|
# ctx.weight_parallel_mode)
|
|
if ctx.use_bias:
|
|
bias_grad = torch.sum(output_grad,
|
|
dim=tuple(
|
|
range(len(output_grad.shape))[:-1]))
|
|
# bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode)
|
|
# dist.all_reduce(bias_grad,
|
|
# group=gpc.get_group(ctx.weight_parallel_mode))
|
|
weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
|
|
|
|
weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
|
|
|
|
input_op.wait()
|
|
weight_op.wait()
|
|
if ctx.use_bias:
|
|
bias_grad = weight_grad[-1]
|
|
weight_grad = weight_grad[:-1]
|
|
|
|
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
|
|
|
|
|
|
class layer_norm_3d(torch.autograd.Function):
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
|
|
normalized_shape: int, eps: float,
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode) -> Tensor:
|
|
# mean = torch.sum(input_, dim=-1)
|
|
# dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
|
|
# mean /= normalized_shape
|
|
# mu = input_ - mean
|
|
# var = torch.sum(torch.pow(mu, 2), dim=-1)
|
|
# dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
|
|
# var /= normalized_shape
|
|
# std_dev = torch.sqrt(var + eps)
|
|
# ctx.save_for_backward(input_, mu, std_dev, weight)
|
|
|
|
# output = weight * mu / std_dev + bias
|
|
|
|
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
|
|
output_parallel_mode) / normalized_shape
|
|
mu = input_ - mean
|
|
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
|
|
output_parallel_mode) / normalized_shape
|
|
sigma = torch.sqrt(var + eps)
|
|
|
|
# ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
|
# src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
|
# transforms = torch.stack([weight, bias]).contiguous()
|
|
# dist.broadcast(transforms,
|
|
# src=src_rank,
|
|
# group=gpc.get_group(input_parallel_mode))
|
|
# transforms = all_gather(transforms, -1, weight_parallel_mode)
|
|
# weight, bias = transforms[0], transforms[1]
|
|
|
|
ctx.save_for_backward(mu, sigma, weight)
|
|
|
|
z = mu / sigma
|
|
output = weight * z + bias
|
|
|
|
# ctx.src_rank = src_rank
|
|
ctx.normalized_shape = normalized_shape
|
|
ctx.input_parallel_mode = input_parallel_mode
|
|
ctx.weight_parallel_mode = weight_parallel_mode
|
|
ctx.output_parallel_mode = output_parallel_mode
|
|
|
|
return output
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
mu, sigma, weight = ctx.saved_tensors
|
|
with torch.no_grad():
|
|
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
|
|
grads = torch.stack([bias_grad, weight_grad]).contiguous()
|
|
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
|
|
grads = all_reduce(grads, ctx.weight_parallel_mode)
|
|
grads = all_reduce(grads, ctx.input_parallel_mode)
|
|
bias_grad, weight_grad = grads[0], grads[1]
|
|
|
|
# grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
|
|
# dist.reduce(grads,
|
|
# dst=ctx.src_rank,
|
|
# group=gpc.get_group(ctx.input_parallel_mode))
|
|
# if gpc.get_local_rank(
|
|
# ctx.input_parallel_mode) == gpc.get_local_rank(
|
|
# ctx.output_parallel_mode):
|
|
# bias_grad, weight_grad = grads[0], grads[1]
|
|
# else:
|
|
# bias_grad, weight_grad = None, None
|
|
|
|
dz = output_grad * weight
|
|
dvar = dz * mu * (-0.5) * sigma**(-3)
|
|
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
|
|
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
|
|
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
|
|
|
|
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
|
|
|
|
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
|
|
|
|
|
class Matmul_AB_3D(torch.autograd.Function):
|
|
"""Matrix multiplication for :math:`C = AB`
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any,
|
|
A: Tensor,
|
|
B: Tensor,
|
|
depth: int,
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode,
|
|
input_dim: int = 0,
|
|
weight_dim: int = -1,
|
|
output_dim: int = 0) -> Tensor:
|
|
# A: [m/q^2, n, k/q]
|
|
# B: [k/q, h/q^2]
|
|
# C: [m/q^2, n, h/q]
|
|
ctx.save_for_backward(A, B)
|
|
|
|
assert A.shape[-1] == B.shape[0], \
|
|
'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)
|
|
|
|
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
|
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
|
|
|
|
C = torch.matmul(A_temp, B_temp)
|
|
out = reduce_scatter(C, output_dim, output_parallel_mode)
|
|
|
|
ctx.depth = depth
|
|
ctx.A_group_parallel_mode = input_parallel_mode
|
|
ctx.B_group_parallel_mode = weight_parallel_mode
|
|
ctx.C_group_parallel_mode = output_parallel_mode
|
|
ctx.A_dim = input_dim
|
|
ctx.B_dim = weight_dim
|
|
ctx.C_dim = output_dim
|
|
|
|
return out
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
A, B = ctx.saved_tensors
|
|
with torch.no_grad():
|
|
A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
|
|
ctx.C_group_parallel_mode,
|
|
ctx.B_group_parallel_mode,
|
|
ctx.A_group_parallel_mode, ctx.C_dim,
|
|
ctx.B_dim, ctx.A_dim)
|
|
B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
|
|
ctx.A_group_parallel_mode,
|
|
ctx.C_group_parallel_mode,
|
|
ctx.B_group_parallel_mode, ctx.A_dim,
|
|
ctx.C_dim, ctx.B_dim)
|
|
return A_grad, B_grad, None, None, None, None, None, None, None
|
|
|
|
|
|
class Matmul_ABT_3D(torch.autograd.Function):
|
|
"""Matrix multiplication for :math:`C = AB^T`
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any,
|
|
A: Tensor,
|
|
B: Tensor,
|
|
depth: int,
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode,
|
|
input_dim: int = 0,
|
|
weight_dim: int = -1,
|
|
output_dim: int = 0) -> Tensor:
|
|
# A: [m/q^2, n, h/q]
|
|
# B: [k/q, h/q^2]
|
|
# C: [m/q^2, n, k/q]
|
|
ctx.save_for_backward(A, B)
|
|
|
|
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
|
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
|
|
|
|
C = torch.matmul(A_temp, B_temp.transpose(0, 1))
|
|
out = reduce_scatter(C, output_dim, output_parallel_mode)
|
|
|
|
ctx.depth = depth
|
|
ctx.A_group_parallel_mode = input_parallel_mode
|
|
ctx.B_group_parallel_mode = weight_parallel_mode
|
|
ctx.C_group_parallel_mode = output_parallel_mode
|
|
ctx.A_dim = input_dim
|
|
ctx.B_dim = weight_dim
|
|
ctx.C_dim = output_dim
|
|
|
|
return out
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
A, B = ctx.saved_tensors
|
|
with torch.no_grad():
|
|
A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
|
|
ctx.C_group_parallel_mode,
|
|
ctx.B_group_parallel_mode,
|
|
ctx.A_group_parallel_mode, ctx.C_dim,
|
|
ctx.B_dim, ctx.A_dim)
|
|
B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
|
|
ctx.C_group_parallel_mode,
|
|
ctx.A_group_parallel_mode,
|
|
ctx.B_group_parallel_mode, ctx.C_dim,
|
|
ctx.A_dim, ctx.B_dim)
|
|
return A_grad, B_grad, None, None, None, None, None, None, None
|
|
|
|
|
|
class Matmul_ATB_3D(torch.autograd.Function):
|
|
"""Matrix multiplication for :math:`C = A^TB`
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any,
|
|
A: Tensor,
|
|
B: Tensor,
|
|
depth: int,
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode,
|
|
input_dim: int = 0,
|
|
weight_dim: int = 0,
|
|
output_dim: int = -1) -> Tensor:
|
|
# A: [m/q^2, n, k/q]
|
|
# B: [m/q^2, n, h/q]
|
|
# C: [k/q, h/q^2]
|
|
ctx.save_for_backward(A, B)
|
|
|
|
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
|
A_temp = A_temp.reshape(-1, A.shape[-1])
|
|
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
|
|
B_temp = B_temp.reshape(-1, B.shape[-1])
|
|
|
|
C = torch.matmul(A_temp.transpose(0, 1), B_temp)
|
|
out = reduce_scatter(C, output_dim, output_parallel_mode)
|
|
|
|
ctx.depth = depth
|
|
ctx.A_group_parallel_mode = input_parallel_mode
|
|
ctx.B_group_parallel_mode = weight_parallel_mode
|
|
ctx.C_group_parallel_mode = output_parallel_mode
|
|
ctx.A_dim = input_dim
|
|
ctx.B_dim = weight_dim
|
|
ctx.C_dim = output_dim
|
|
|
|
return out
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
A, B = ctx.saved_tensors
|
|
with torch.no_grad():
|
|
A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
|
|
ctx.B_group_parallel_mode,
|
|
ctx.C_group_parallel_mode,
|
|
ctx.A_group_parallel_mode, ctx.B_dim,
|
|
ctx.C_dim, ctx.A_dim)
|
|
B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
|
|
ctx.A_group_parallel_mode,
|
|
ctx.C_group_parallel_mode,
|
|
ctx.B_group_parallel_mode, ctx.A_dim,
|
|
ctx.C_dim, ctx.B_dim)
|
|
return A_grad, B_grad, None, None, None, None, None, None, None
|
|
|
|
|
|
class Add_3D(torch.autograd.Function):
|
|
"""Matrix add bias: :math:`C = A + b`
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode) -> Tensor:
|
|
# input: [m/q^2, n, h/q]
|
|
# bias: [h/q^2]
|
|
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
|
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
|
bias_temp = bias.clone()
|
|
dist.broadcast(bias_temp,
|
|
src=src_rank,
|
|
group=gpc.get_group(input_parallel_mode))
|
|
# [h/q]
|
|
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
|
|
|
|
out = input_ + bias_temp
|
|
|
|
ctx.depth = depth
|
|
ctx.src_rank = src_rank
|
|
ctx.A_group_parallel_mode = input_parallel_mode
|
|
ctx.B_group_parallel_mode = weight_parallel_mode
|
|
ctx.C_group_parallel_mode = output_parallel_mode
|
|
|
|
return out
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
# output_grad: [m/q^2, n, h/q]
|
|
with torch.no_grad():
|
|
# [h/q]
|
|
grad = torch.sum(output_grad,
|
|
dim=tuple(range(len(output_grad.shape))[:-1]))
|
|
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
|
|
dist.reduce(bias_grad,
|
|
dst=ctx.src_rank,
|
|
group=gpc.get_group(ctx.A_group_parallel_mode))
|
|
if gpc.get_local_rank(
|
|
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
|
ctx.C_group_parallel_mode):
|
|
bias_grad = None
|
|
return output_grad, bias_grad, None, None, None, None
|
|
|
|
|
|
class Mul_3D(torch.autograd.Function):
|
|
"""Matrix multiplication for :math:`C = A * b`
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
|
|
input_parallel_mode: ParallelMode,
|
|
weight_parallel_mode: ParallelMode,
|
|
output_parallel_mode: ParallelMode) -> Tensor:
|
|
# input: [m/q^2, n, h/q]
|
|
# bias: [h/q^2]
|
|
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
|
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
|
# [h/q^2]
|
|
bias_temp = bias.clone()
|
|
dist.broadcast(bias_temp,
|
|
src=src_rank,
|
|
group=gpc.get_group(input_parallel_mode))
|
|
# [h/q]
|
|
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
|
|
|
|
# empty_cache()
|
|
ctx.save_for_backward(input_, bias_temp)
|
|
|
|
out = torch.mul(input_, bias_temp)
|
|
|
|
ctx.depth = depth
|
|
ctx.src_rank = src_rank
|
|
ctx.A_group_parallel_mode = input_parallel_mode
|
|
ctx.B_group_parallel_mode = weight_parallel_mode
|
|
ctx.C_group_parallel_mode = output_parallel_mode
|
|
|
|
return out
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
# output_grad: [m/q^2, n, h/q]
|
|
with torch.no_grad():
|
|
input_, bias = ctx.saved_tensors
|
|
# [m/q^2, n, h/q]
|
|
input_grad = torch.mul(output_grad, bias)
|
|
# [h/q]
|
|
grad = torch.mul(output_grad, input_)
|
|
grad = torch.sum(grad,
|
|
dim=tuple(range(len(output_grad.shape))[:-1]))
|
|
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
|
|
dist.reduce(bias_grad,
|
|
dst=ctx.src_rank,
|
|
group=gpc.get_group(ctx.A_group_parallel_mode))
|
|
if gpc.get_local_rank(
|
|
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
|
ctx.C_group_parallel_mode):
|
|
bias_grad = None
|
|
return input_grad, bias_grad, None, None, None, None
|
|
|
|
|
|
class Sum_3D(torch.autograd.Function):
|
|
"""Compute the sum of input tensors
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any,
|
|
input_: Tensor,
|
|
dim: int,
|
|
depth: int,
|
|
parallel_mode: ParallelMode,
|
|
keepdim: bool = False) -> Tensor:
|
|
# input: [m/q^2, n, h/q]
|
|
out = torch.sum(input_, dim=dim, keepdim=keepdim)
|
|
dist.all_reduce(out, group=gpc.get_group(parallel_mode))
|
|
|
|
ctx.input_shape = input_.shape
|
|
ctx.depth = depth
|
|
ctx.group = parallel_mode
|
|
ctx.dim = dim
|
|
return out
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
with torch.no_grad():
|
|
output_grad = output_grad.contiguous()
|
|
dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
|
|
if len(output_grad.shape) < len(ctx.input_shape):
|
|
output_grad = torch.unsqueeze(output_grad, ctx.dim)
|
|
dims = [1 for _ in range(len(output_grad.shape))]
|
|
dims[ctx.dim] = ctx.input_shape[ctx.dim]
|
|
input_grad = output_grad.repeat(tuple(dims))
|
|
return input_grad, None, None, None, None, None
|
|
|
|
|
|
class Reduce_3D(torch.autograd.Function):
|
|
"""Reduce input tensors
|
|
"""
|
|
@staticmethod
|
|
@custom_fwd(cast_inputs=torch.float16)
|
|
def forward(ctx: Any, input_: Tensor, depth: int,
|
|
parallel_mode: ParallelMode) -> Tensor:
|
|
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
|
|
return input_.clone()
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
return output_grad, None, None
|
|
|
|
|
|
# class Slice_3D(torch.autograd.Function):
|
|
# """Slice input tensor
|
|
# """
|
|
# @staticmethod
|
|
# @custom_fwd(cast_inputs=torch.float16)
|
|
# def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
|
|
# parallel_mode: ParallelMode) -> Tensor:
|
|
# rank = gpc.get_local_rank(parallel_mode)
|
|
# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
|
|
|
|
# ctx.depth = depth
|
|
# ctx.parallel_mode = parallel_mode
|
|
# ctx.dim = dim
|
|
# ctx.input_shape = input_.shape
|
|
|
|
# return out
|
|
|
|
# @staticmethod
|
|
# @custom_bwd
|
|
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
# with torch.no_grad():
|
|
# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
|
|
# input_grad.reshape(ctx.input_shape)
|
|
# return input_grad, None, None, None
|