mirror of https://github.com/hpcaitech/ColossalAI
[tensorparallel] fixed tp layers (#1938)
parent cf68cc92ac
commit e52f9d9109
@@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         parallel_input = get_parallel_input()
-        if not parallel_input:
+        if not parallel_input and not gather_output:
             layer = Linear1D_Col(in_features,
                                  out_features,
                                  bias=bias,
                                  dtype=dtype,
-                                 gather_output=gather_output,
                                  skip_bias_add=skip_bias_add,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
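The column-parallel path is now taken only when the caller does not ask for a gathered output, and the redundant `gather_output` keyword is dropped from the `Linear1D_Col` call. The sketch below (illustration only, not part of the commit; plain PyTorch with the "ranks" simulated by tensor chunks) shows why the distinction matters: a column-parallel linear leaves each rank with only a slice of the output until the slices are gathered.

```python
# Hypothetical single-process illustration (plain PyTorch, not ColossalAI code):
# a column-parallel linear splits the weight along the output dimension, so each
# simulated "rank" holds only a slice of the output until the slices are gathered.
import torch

torch.manual_seed(0)
world_size, in_features, out_features = 2, 8, 4
x = torch.randn(3, in_features)
w = torch.randn(out_features, in_features)

full = x @ w.t()                                 # serial reference result

shards = w.chunk(world_size, dim=0)              # each rank owns out_features // world_size rows of w
partials = [x @ shard.t() for shard in shards]   # per-rank partial outputs
gathered = torch.cat(partials, dim=-1)           # what gather_output / an all-gather would produce

assert torch.allclose(full, gathered)
print(gathered.shape)  # torch.Size([3, 4]) -- only correct after the gather step
```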
@@ -4,13 +4,15 @@
 from typing import Optional, Tuple

 import torch
-from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
-from ._utils import get_parallel_mode_from_env, push_async_grad
+
+from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+
+from ._utils import get_parallel_mode_from_env, push_async_grad


 class _Linear3D(torch.autograd.Function):
@@ -44,18 +46,17 @@ class _Linear3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
+        output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

-            input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
-            input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
+        input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
+        input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

-            weight_grad = torch.matmul(
-                input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-            weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
-            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        weight_grad = torch.matmul(
+            input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
+        weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
+        weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

-            input_op.wait()
+        input_op.wait()

         return input_grad, weight_grad, None, None, None, None

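The backward body is dedented out of the `torch.no_grad()` block; the math is unchanged, and gradient tracking is already off inside a custom `Function.backward` unless a double-backward graph is requested, so the wrapper was redundant. As a sanity check on the formulas themselves, here is a single-device sketch (assumption: world size 1, so `all_gather` and `reduce_scatter` act as identities and are dropped) comparing them against ordinary autograd for `y = x @ w`:

```python
# Minimal single-device sanity check (assumption: world size 1, collectives dropped).
# The two matmul formulas used in the backward above match ordinary autograd.
import torch

x = torch.randn(5, 7, requires_grad=True)
w = torch.randn(7, 3, requires_grad=True)
out = x @ w
grad_out = torch.randn_like(out)
out.backward(grad_out)

input_grad = torch.matmul(grad_out, w.transpose(0, 1))
weight_grad = torch.matmul(x.reshape(-1, x.shape[-1]).transpose(0, 1),
                           grad_out.reshape(-1, grad_out.shape[-1]))

assert torch.allclose(input_grad, x.grad)
assert torch.allclose(weight_grad, w.grad)
```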
@@ -129,25 +130,24 @@ class _Classifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            weight_grad = torch.matmul(
-                output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
-            weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
-            if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
-                weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
-                weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
-            else:
-                weight_grad = None
+        weight_grad = torch.matmul(
+            output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
+        weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
+        if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
+            weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        else:
+            weight_grad = None

-            if ctx.use_bias:
-                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
-                bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
-                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
-                bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
-            else:
-                bias_grad = None
+        if ctx.use_bias:
+            bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
+            bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
+            bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
+            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
+        else:
+            bias_grad = None

-            input_grad = torch.matmul(output_grad, weight)
+        input_grad = torch.matmul(output_grad, weight)

         return input_grad, weight_grad, bias_grad, None, None, None, None, None

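Same dedent for the classifier backward. One detail worth spelling out is the bias path: for a bias broadcast over every leading dimension, the gradient is simply the output gradient summed over those dimensions, which is what the `torch.sum(..., dim=tuple(range(len(output_grad.shape))[:-1]))` line computes before the collectives combine the partial sums. A minimal single-device check (collectives omitted, illustration only):

```python
# Quick single-device check of the bias reduction used above (collectives omitted):
# for a bias broadcast over every leading dimension, its gradient is the output
# gradient summed over those dimensions.
import torch

bias = torch.zeros(6, requires_grad=True)
out = torch.randn(2, 3, 6) + bias
grad_out = torch.randn_like(out)
out.backward(grad_out)

bias_grad = torch.sum(grad_out, dim=tuple(range(len(grad_out.shape))[:-1]))
assert torch.allclose(bias_grad, bias.grad)
```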
@@ -224,25 +224,24 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
+        output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

-            input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
-            input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
+        input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
+        input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

-            weight_grad = torch.matmul(
-                input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-            weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
-            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        weight_grad = torch.matmul(
+            input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
+        weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
+        weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

-            if ctx.use_bias:
-                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
-                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
-                bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
-            else:
-                bias_grad = None
+        if ctx.use_bias:
+            bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
+            bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
+            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
+        else:
+            bias_grad = None

-            input_op.wait()
+        input_op.wait()

         return input_grad, weight_grad, bias_grad, None, None, None, None, None

@@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d(
     )


+@torch.jit.script
+def norm_forward(x, mean, sqr_mean, weight, bias, eps):
+    mu = x - mean
+    var = sqr_mean - mean**2
+    sigma = torch.sqrt(var + eps)
+    z = mu / sigma
+    output = weight * z + bias
+
+    return output, mu, sigma
+
+
+@torch.jit.script
+def norm_backward(grad, mu, sigma, weight):
+    # dbias, dweight = grad, grad * mu / sigma
+    dz = grad * weight
+    dmu = dz / sigma
+    dvar = dz * mu * (-0.5) * sigma**(-3)
+    dmean = -dmu
+    dvar = torch.sum(dvar, -1, keepdim=True)
+    dmean = torch.sum(dmean, -1, keepdim=True)
+
+    return dmu, dmean, dvar
+
+
 class _Layernorm3D(torch.autograd.Function):

     @staticmethod
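The two new TorchScript helpers keep the elementwise layer-norm arithmetic in one scripted function each, so TorchScript can fuse the chain of pointwise ops. With the distributed statistics out of the picture, `norm_forward` reduces to ordinary layer normalization; the check below (illustration only: the helper is re-declared without `@torch.jit.script` and with an explicit `eps` annotation so the snippet is self-contained) verifies it against `F.layer_norm`:

```python
# Hypothetical single-device check: with a world size of 1, sqr_mean is just E[x**2],
# so norm_forward reduces to ordinary layer normalization.
import torch
import torch.nn.functional as F

def norm_forward(x, mean, sqr_mean, weight, bias, eps: float):
    mu = x - mean
    var = sqr_mean - mean**2
    sigma = torch.sqrt(var + eps)
    z = mu / sigma
    return weight * z + bias, mu, sigma

x = torch.randn(2, 4, 16, dtype=torch.float64)
w = torch.randn(16, dtype=torch.float64)
b = torch.randn(16, dtype=torch.float64)

mean = x.mean(dim=-1, keepdim=True)
sqr_mean = (x * x).mean(dim=-1, keepdim=True)
out, _, _ = norm_forward(x, mean, sqr_mean, w, b, 1e-5)
ref = F.layer_norm(x, (16,), w, b, 1e-5)

assert torch.allclose(out, ref)
```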
@@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function):
         bias_id: int,
         normalized_shape: int,
         eps: float,
-        input_parallel_mode: ParallelMode,
-        weight_parallel_mode: ParallelMode,
         output_parallel_mode: ParallelMode,
         input_x_weight_parallel_mode: ParallelMode,
     ) -> Tensor:
         ctx.weight_id = weight_id
         ctx.bias_id = bias_id

-        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        mu = input_ - mean
-        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        sigma = torch.sqrt(var + eps)
+        sum_ = torch.sum(input_, dim=-1, keepdim=True)
+        sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
+        mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
+
+        output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)

         ctx.save_for_backward(mu, sigma, weight)

-        z = mu / sigma
-        output = weight * z + bias
-
         ctx.normalized_shape = normalized_shape
-        ctx.input_parallel_mode = input_parallel_mode
-        ctx.weight_parallel_mode = weight_parallel_mode
         ctx.output_parallel_mode = output_parallel_mode
         ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode

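The rewritten forward issues a single `all_reduce` over a stacked tensor instead of two separate ones for the mean and the variance, relying on the identity Var[x] = E[x^2] - E[x]^2. A small single-device illustration of the stacked unpacking and of that identity (the collective is dropped since the world size is 1):

```python
# Hypothetical single-device illustration of the fused statistics: stacking the two
# partial sums lets one collective carry both moments, and E[x^2] - E[x]^2 recovers
# the biased variance that layer norm needs.
import torch

n = 16
x = torch.randn(4, n, dtype=torch.float64)
sum_ = torch.sum(x, dim=-1, keepdim=True)
sqr_sum = torch.sum(x**2, dim=-1, keepdim=True)

mean, sqr_mean = torch.stack((sum_, sqr_sum)) / n   # a single all_reduce would go here
var = sqr_mean - mean**2

assert torch.allclose(var, x.var(dim=-1, unbiased=False, keepdim=True))
```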
@@ -324,23 +341,18 @@ class _Layernorm3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
-        with torch.no_grad():

-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
-            bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
-            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
-            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
-            weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
-            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        bias_grad, weight_grad = output_grad, output_grad * mu / sigma
+        bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
+        bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
+        bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
+        weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
+        weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
+        weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

-            dz = output_grad * weight
-            dvar = dz * mu * (-0.5) * sigma**(-3)
-            dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
-            dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
-            dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
-
-            input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+        dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
+        dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
+        input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape

         return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None

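The new backward is algebraically equivalent to the old one; it is reorganized so the two row reductions (`dmean`, `dvar`) share one stacked `all_reduce` and the pointwise work lives in the scripted `norm_backward`. The single-device sketch below (illustration only: world size 1, collective dropped, helper copied without the TorchScript decorator) checks the composed formulas against autograd through `F.layer_norm`:

```python
# Hypothetical single-device check of the fused layer-norm backward path.
import torch
import torch.nn.functional as F

def norm_backward(grad, mu, sigma, weight):
    dz = grad * weight
    dmu = dz / sigma
    dvar = dz * mu * (-0.5) * sigma**(-3)
    dmean = -dmu
    dvar = torch.sum(dvar, -1, keepdim=True)
    dmean = torch.sum(dmean, -1, keepdim=True)
    return dmu, dmean, dvar

n, eps = 16, 1e-5
x = torch.randn(2, 3, n, dtype=torch.float64, requires_grad=True)
w = torch.randn(n, dtype=torch.float64, requires_grad=True)
b = torch.randn(n, dtype=torch.float64, requires_grad=True)

out = F.layer_norm(x, (n,), w, b, eps)
grad_out = torch.randn_like(out)
out.backward(grad_out)

with torch.no_grad():
    mean = x.mean(dim=-1, keepdim=True)
    mu = x - mean
    sigma = torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)

    dmu, dmean, dvar = norm_backward(grad_out, mu, sigma, w)
    input_grad = dmu + (dmean + 2 * dvar * mu) / n
    weight_grad = torch.sum(grad_out * mu / sigma, dim=(0, 1))
    bias_grad = torch.sum(grad_out, dim=(0, 1))

assert torch.allclose(input_grad, x.grad)
assert torch.allclose(weight_grad, w.grad)
assert torch.allclose(bias_grad, b.grad)
```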
@@ -351,8 +363,6 @@ def layernorm_3d(
     bias: Tensor,
     normalized_shape: int,
     eps: float,
-    input_parallel_mode: ParallelMode,
-    weight_parallel_mode: ParallelMode,
     output_parallel_mode: ParallelMode,
     input_x_weight_parallel_mode: ParallelMode,
 ) -> Tensor:
@@ -368,9 +378,8 @@ def layernorm_3d(
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float): a value added to the denominator for numerical stability
-        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
-        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
         output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
         input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.

     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
@@ -384,8 +393,6 @@ def layernorm_3d(
         id(bias),
         normalized_shape,
         eps,
-        input_parallel_mode,
-        weight_parallel_mode,
         output_parallel_mode,
         input_x_weight_parallel_mode,
     )
@@ -5,6 +5,9 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed
@@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
-                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
-from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
+from ._operation import (
+    all_gather_tensor_3d,
+    classifier_3d,
+    layernorm_3d,
+    linear_3d,
+    reduce_scatter_tensor_3d,
+    split_batch_3d,
+    split_tensor_3d,
+    vocab_parallel_classifier_3d,
+)
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group


 @LAYERS.register_module
@@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer):
             self.bias,
             self.normalized_shape,
             self.variance_epsilon,
-            self.input_parallel_mode,
-            self.weight_parallel_mode,
             self.output_parallel_mode,
             self.input_x_weight_parallel_mode,
         )
@@ -900,7 +910,7 @@ class PatchEmbedding3D(ParallelLayer):
                                 weight_parallel_mode=self.weight_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)