[tensorparallel] fixed tp layers (#1938)

pull/1944/head
アマデウス 2022-11-14 17:34:03 +08:00 committed by GitHub
parent cf68cc92ac
commit e52f9d9109
3 changed files with 105 additions and 89 deletions


@@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         parallel_input = get_parallel_input()
-        if not parallel_input:
+        if not parallel_input and not gather_output:
             layer = Linear1D_Col(in_features,
                                  out_features,
                                  bias=bias,
                                  dtype=dtype,
-                                 gather_output=gather_output,
                                  skip_bias_add=skip_bias_add,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
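The wrapper now takes the `Linear1D_Col` path only when the input is not already parallel and no output gather is requested, so `gather_output` is no longer forwarded on this branch. As a rough single-device sketch of what `gather_output` means for a column-parallel linear (plain PyTorch, illustrative names, not the Colossal-AI API): the output features are sharded across ranks, and gathering concatenates the per-rank partial results back into the full output.

```python
import torch

# Illustrative single-process sketch (names here are made up for the example):
# a column-parallel linear shards the output features across "ranks", and
# gather_output corresponds to concatenating the per-rank partial outputs.
torch.manual_seed(0)
x = torch.randn(4, 8)                      # (batch, in_features)
full_weight = torch.randn(8, 16)           # (in_features, out_features)

shards = full_weight.chunk(2, dim=1)       # each "rank" holds half of the output columns
partials = [x @ w for w in shards]         # what each rank computes locally

gathered = torch.cat(partials, dim=-1)     # gather_output=True ~ all-gather along the feature dim
assert torch.allclose(gathered, x @ full_weight, atol=1e-6)
```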


@@ -4,13 +4,15 @@
 from typing import Optional, Tuple
 
 import torch
-from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
-from ._utils import get_parallel_mode_from_env, push_async_grad
+from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+
+from ._utils import get_parallel_mode_from_env, push_async_grad
 
 
 class _Linear3D(torch.autograd.Function):
@@ -44,18 +46,17 @@ class _Linear3D(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
+        output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
 
-            input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
-            input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
+        input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
+        input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
 
-            weight_grad = torch.matmul(
-                input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-            weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
-            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        weight_grad = torch.matmul(
+            input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
+        weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
+        weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
-            input_op.wait()
+        input_op.wait()
 
         return input_grad, weight_grad, None, None, None, None
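Removing the `torch.no_grad()` wrapper does not change behaviour here: the backward of a custom `autograd.Function` already runs with gradient recording disabled unless `create_graph=True` is requested, and the local math is the usual matmul backward for `Y = X @ W`. A minimal single-device sanity check of those formulas (toy shapes, no 3D collectives):

```python
import torch

# For Y = X @ W, autograd agrees with dX = dY @ W.T and dW = X.T @ dY,
# which is what the backward above computes before the reduce-scatter collectives.
X = torch.randn(4, 8, requires_grad=True)
W = torch.randn(8, 16, requires_grad=True)
dY = torch.randn(4, 16)

(X @ W).backward(dY)
assert torch.allclose(X.grad, dY @ W.t(), atol=1e-6)
assert torch.allclose(W.grad, X.t() @ dY, atol=1e-6)
```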
@@ -129,25 +130,24 @@ class _Classifier3D(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            weight_grad = torch.matmul(
-                output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
-            weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
-            if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
-                weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
-                weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
-            else:
-                weight_grad = None
+        weight_grad = torch.matmul(
+            output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
+        weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
+        if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
+            weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        else:
+            weight_grad = None
 
-            if ctx.use_bias:
-                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
-                bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
-                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
-                bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
-            else:
-                bias_grad = None
+        if ctx.use_bias:
+            bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
+            bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
+            bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
+            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
+        else:
+            bias_grad = None
 
-            input_grad = torch.matmul(output_grad, weight)
+        input_grad = torch.matmul(output_grad, weight)
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
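The classifier backward keeps the same structure, only dedented: the weight gradient is reduced to `ctx.src_rank` of the input group, and the bias gradient is the output gradient summed over every dimension except the last before the group reductions. A toy single-device check of that bias-reduction pattern against autograd (hypothetical shapes, no parallel groups):

```python
import torch

# For y = x @ w.t() + b, the bias gradient is dY summed over all leading
# (batch/sequence) dimensions -- the same torch.sum(...) used above.
x = torch.randn(2, 5, 8)
w = torch.randn(16, 8)
b = torch.zeros(16, requires_grad=True)
dY = torch.randn(2, 5, 16)

(x @ w.t() + b).backward(dY)
expected = torch.sum(dY, dim=tuple(range(len(dY.shape))[:-1]))
assert torch.allclose(b.grad, expected, atol=1e-6)
```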
@@ -224,25 +224,24 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
+        output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
 
-            input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
-            input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
+        input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
+        input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
 
-            weight_grad = torch.matmul(
-                input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-            weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
-            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        weight_grad = torch.matmul(
+            input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
+        weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
+        weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
-            if ctx.use_bias:
-                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
-                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
-                bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
-            else:
-                bias_grad = None
+        if ctx.use_bias:
+            bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
+            bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
+            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
+        else:
+            bias_grad = None
 
-            input_op.wait()
+        input_op.wait()
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
@@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d(
     )
 
 
+@torch.jit.script
+def norm_forward(x, mean, sqr_mean, weight, bias, eps):
+    mu = x - mean
+    var = sqr_mean - mean**2
+    sigma = torch.sqrt(var + eps)
+    z = mu / sigma
+    output = weight * z + bias
+
+    return output, mu, sigma
+
+
+@torch.jit.script
+def norm_backward(grad, mu, sigma, weight):
+    # dbias, dweight = grad, grad * mu / sigma
+    dz = grad * weight
+    dmu = dz / sigma
+    dvar = dz * mu * (-0.5) * sigma**(-3)
+    dmean = -dmu
+    dvar = torch.sum(dvar, -1, keepdim=True)
+    dmean = torch.sum(dmean, -1, keepdim=True)
+
+    return dmu, dmean, dvar
+
+
 class _Layernorm3D(torch.autograd.Function):
 
     @staticmethod
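The two jitted helpers added above isolate the elementwise part of LayerNorm. They consume pre-reduced `mean` and `sqr_mean` so the distributed statistics can be produced by a single fused `all_reduce` (see the forward below). On one device the same arithmetic reproduces `torch.nn.functional.layer_norm`; a standalone check that re-derives the formula rather than importing the helpers:

```python
import torch
import torch.nn.functional as F

# Single-device check that the (mean, sqr_mean) formulation matches F.layer_norm.
x = torch.randn(4, 32)
weight, bias, eps = torch.randn(32), torch.randn(32), 1e-5

mean = x.mean(dim=-1, keepdim=True)
sqr_mean = (x**2).mean(dim=-1, keepdim=True)
sigma = torch.sqrt(sqr_mean - mean**2 + eps)     # Var(x) = E[x^2] - E[x]^2
out = weight * (x - mean) / sigma + bias

ref = F.layer_norm(x, (32,), weight, bias, eps)
assert torch.allclose(out, ref, atol=1e-4)
```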
@@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function):
         bias_id: int,
         normalized_shape: int,
         eps: float,
-        input_parallel_mode: ParallelMode,
-        weight_parallel_mode: ParallelMode,
         output_parallel_mode: ParallelMode,
         input_x_weight_parallel_mode: ParallelMode,
     ) -> Tensor:
         ctx.weight_id = weight_id
         ctx.bias_id = bias_id
 
-        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        mu = input_ - mean
-        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        sigma = torch.sqrt(var + eps)
+        sum_ = torch.sum(input_, dim=-1, keepdim=True)
+        sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
+        mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
+        output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)
 
         ctx.save_for_backward(mu, sigma, weight)
-        z = mu / sigma
-        output = weight * z + bias
 
         ctx.normalized_shape = normalized_shape
-        ctx.input_parallel_mode = input_parallel_mode
-        ctx.weight_parallel_mode = weight_parallel_mode
         ctx.output_parallel_mode = output_parallel_mode
         ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
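The rewritten forward computes the local sum and squared sum, stacks them, and issues one `all_reduce` in place of the two sequential reductions (mean, then variance) of the old code. This relies on the standard identity

$$
\mu = \frac{1}{N}\sum_i x_i, \qquad
\overline{x^2} = \frac{1}{N}\sum_i x_i^2, \qquad
\sigma = \sqrt{\overline{x^2} - \mu^2 + \epsilon},
$$

so both statistics are recoverable from per-partition sums of $x$ and $x^2$, and `norm_forward` finishes the elementwise normalization from the reduced values.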
@@ -324,23 +341,18 @@ class _Layernorm3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
-        with torch.no_grad():
-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
-            bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
-            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
-            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
-            weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
-            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
+        bias_grad, weight_grad = output_grad, output_grad * mu / sigma
+        bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
+        bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
+        bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
+        weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
+        weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
+        weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
-            dz = output_grad * weight
-            dvar = dz * mu * (-0.5) * sigma**(-3)
-            dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
-            dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
-            dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
-
-            input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+        dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
+        dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
+        input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
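The refactored backward is algebraically unchanged. With $g$ the incoming gradient, $w$ the affine weight, $\mu_i = x_i - \mathrm{mean}$ and $N$ the normalized size, the LayerNorm input gradient is

$$
\frac{\partial L}{\partial x_i}
= \frac{1}{\sigma}\left( g_i w_i
- \frac{1}{N}\sum_j g_j w_j
- \frac{\mu_i}{N\sigma^2}\sum_j g_j w_j \mu_j \right),
$$

which is exactly `dmu + (dmean + 2 * dvar * mu) / N` once the partial sums in `dmean` and `dvar` have been all-reduced across the output group. The extra $-2\mu \cdot \mathrm{dvar}/N$ term that the old `dmean` carried sums to zero over the full normalized dimension (because $\sum_j \mu_j = 0$), so dropping it does not affect the result.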
@@ -351,8 +363,6 @@ def layernorm_3d(
     bias: Tensor,
     normalized_shape: int,
     eps: float,
-    input_parallel_mode: ParallelMode,
-    weight_parallel_mode: ParallelMode,
     output_parallel_mode: ParallelMode,
     input_x_weight_parallel_mode: ParallelMode,
 ) -> Tensor:
@@ -368,9 +378,8 @@ def layernorm_3d(
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float): a value added to the denominator for numerical stability
-        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
-        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
         output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+        input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.
 
     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
@@ -384,8 +393,6 @@ def layernorm_3d(
         id(bias),
         normalized_shape,
         eps,
-        input_parallel_mode,
-        weight_parallel_mode,
         output_parallel_mode,
         input_x_weight_parallel_mode,
     )


@@ -5,6 +5,9 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed
@@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
-                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
-from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
+from ._operation import (
+    all_gather_tensor_3d,
+    classifier_3d,
+    layernorm_3d,
+    linear_3d,
+    reduce_scatter_tensor_3d,
+    split_batch_3d,
+    split_tensor_3d,
+    vocab_parallel_classifier_3d,
+)
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group
 
 
 @LAYERS.register_module
@@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer):
             self.bias,
             self.normalized_shape,
             self.variance_epsilon,
-            self.input_parallel_mode,
-            self.weight_parallel_mode,
             self.output_parallel_mode,
             self.input_x_weight_parallel_mode,
         )
@@ -900,7 +910,7 @@ class PatchEmbedding3D(ParallelLayer):
                                       weight_parallel_mode=self.weight_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
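The only change in this hunk is comment spacing. For reference, `flatten(2).transpose(1, 2)` is what turns the convolution's `(B, C, H, W)` patch grid into the `(B, N, C)` token layout the comment describes; a quick shape check with toy sizes (not the module's real configuration):

```python
import torch

# (B, C, H, W) feature map -> (B, N, C) token sequence, with N = H * W
feat = torch.randn(2, 96, 14, 14)
tokens = feat.flatten(2).transpose(1, 2)
assert tokens.shape == (2, 14 * 14, 96)
```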