[tensorparallel] fixed tp layers (#1938)

pull/1944/head
アマデウス 2022-11-14 17:34:03 +08:00 committed by GitHub
parent cf68cc92ac
commit e52f9d9109
3 changed files with 105 additions and 89 deletions

@@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         parallel_input = get_parallel_input()
-        if not parallel_input:
+        if not parallel_input and not gather_output:
             layer = Linear1D_Col(in_features,
                                  out_features,
                                  bias=bias,
                                  dtype=dtype,
-                                 gather_output=gather_output,
                                  skip_bias_add=skip_bias_add,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
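The routing fix above means `Linear1D_Col` is now chosen only when the input is not already parallel and the output does not need to be gathered, so `gather_output` no longer has to be forwarded into the column-parallel layer. As background on why the column- and row-parallel variants differ, here is a minimal single-device sketch (plain PyTorch, no process group; the gather and the all-reduce are emulated with `torch.cat` and summation, and all tensor sizes are illustrative):

```python
import torch

torch.manual_seed(0)
world_size = 4
x = torch.randn(8, 16)     # (batch, in_features)
w = torch.randn(32, 16)    # (out_features, in_features)

# Column parallelism: shard the output features. Each rank computes a slice
# of the output, and a gather along the last dim reassembles the full
# activation (what `gather_output=True` would trigger).
col_shards = w.chunk(world_size, dim=0)
y_col = torch.cat([x @ shard.t() for shard in col_shards], dim=-1)

# Row parallelism: shard the input features. Each rank holds a slice of the
# input and of the weight, and the partial products must be summed
# (an all-reduce) to recover the full output.
x_shards = x.chunk(world_size, dim=-1)
row_shards = w.chunk(world_size, dim=1)
y_row = sum(xs @ ws.t() for xs, ws in zip(x_shards, row_shards))

assert torch.allclose(y_col, x @ w.t(), atol=1e-5)
assert torch.allclose(y_row, x @ w.t(), atol=1e-5)
```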

@@ -4,13 +4,15 @@
 from typing import Optional, Tuple
 
 import torch
-from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
-from ._utils import get_parallel_mode_from_env, push_async_grad
+from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+
+from ._utils import get_parallel_mode_from_env, push_async_grad
 
 
 class _Linear3D(torch.autograd.Function):
@@ -44,18 +46,17 @@ class _Linear3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
         output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
         input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
         input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
         weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
         input_op.wait()
         return input_grad, weight_grad, None, None, None, None
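Stripped of the collectives (`all_gather`/`reduce_scatter` only move shards between ranks), the backward above applies the standard matmul gradients for `Y = X @ W` with `W` of shape `(in_features, out_features)`: `dX = dY @ W.T` and `dW = X.T @ dY`. A single-device check of those formulas against autograd, with illustrative shapes:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 5, 8, requires_grad=True)    # (..., in_features)
w = torch.randn(8, 6, requires_grad=True)       # (in_features, out_features)

y = torch.matmul(x, w)
grad_out = torch.randn_like(y)
y.backward(grad_out)

# Same formulas as the custom backward, minus the collectives.
manual_dx = torch.matmul(grad_out, w.detach().transpose(0, 1))
manual_dw = torch.matmul(
    x.detach().reshape(-1, x.shape[-1]).transpose(0, 1),
    grad_out.reshape(-1, grad_out.shape[-1]))

assert torch.allclose(manual_dx, x.grad, atol=1e-5)
assert torch.allclose(manual_dw, w.grad, atol=1e-5)
```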
@@ -129,25 +130,24 @@ class _Classifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
         weight_grad = torch.matmul(
             output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
         weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
         if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
             weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
             weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
         else:
             weight_grad = None
 
         if ctx.use_bias:
             bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
             bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
             bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
             bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
         else:
             bias_grad = None
 
         input_grad = torch.matmul(output_grad, weight)
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
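The classifier stores its weight as `(num_classes, in_features)`, so the weight gradient is `dY.T @ X` (flattened over the leading dimensions) and the bias gradient sums the output gradient over every dimension except the last; the `reduce`/`all_reduce` calls above only combine those partial results across ranks. A quick single-device check with illustrative shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8)                        # (..., in_features)
w = torch.randn(10, 8, requires_grad=True)      # (num_classes, in_features)
b = torch.randn(10, requires_grad=True)

y = F.linear(x, w, b)                           # y = x @ w.T + b
grad_out = torch.randn_like(y)
y.backward(grad_out)

# Same formulas as the custom backward, minus the cross-rank reductions.
manual_dw = torch.matmul(
    grad_out.reshape(-1, grad_out.shape[-1]).transpose(0, 1),
    x.reshape(-1, x.shape[-1]))
manual_db = torch.sum(grad_out, dim=tuple(range(len(grad_out.shape))[:-1]))

assert torch.allclose(manual_dw, w.grad, atol=1e-5)
assert torch.allclose(manual_db, b.grad, atol=1e-5)
```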
@@ -224,25 +224,24 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
         output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
         input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
         input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
         weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
         if ctx.use_bias:
             bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
             bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
             bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
         else:
             bias_grad = None
 
         input_op.wait()
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
@@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d(
     )
 
 
+@torch.jit.script
+def norm_forward(x, mean, sqr_mean, weight, bias, eps):
+    mu = x - mean
+    var = sqr_mean - mean**2
+    sigma = torch.sqrt(var + eps)
+    z = mu / sigma
+    output = weight * z + bias
+
+    return output, mu, sigma
+
+
+@torch.jit.script
+def norm_backward(grad, mu, sigma, weight):
+    # dbias, dweight = grad, grad * mu / sigma
+    dz = grad * weight
+    dmu = dz / sigma
+    dvar = dz * mu * (-0.5) * sigma**(-3)
+    dmean = -dmu
+    dvar = torch.sum(dvar, -1, keepdim=True)
+    dmean = torch.sum(dmean, -1, keepdim=True)
+
+    return dmu, dmean, dvar
+
+
 class _Layernorm3D(torch.autograd.Function):
 
     @staticmethod
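`norm_forward` normalizes from `mean = E[x]` and `sqr_mean = E[x**2]`, using `Var[x] = E[x**2] - E[x]**2`; this is what lets the forward pass below reduce both statistics in a single collective. A single-device sketch (re-implementing the helper locally, without `torch.jit.script` or the parallel context) checking that this formulation matches `F.layer_norm`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 64
x = torch.randn(4, 16, hidden, dtype=torch.float64)
weight = torch.randn(hidden, dtype=torch.float64)
bias = torch.randn(hidden, dtype=torch.float64)
eps = 1e-5

# Statistics as in the fused forward: E[x] and E[x^2] over the last dim.
mean = x.mean(dim=-1, keepdim=True)
sqr_mean = (x ** 2).mean(dim=-1, keepdim=True)

# Local re-implementation of norm_forward.
mu = x - mean
var = sqr_mean - mean ** 2          # Var[x] = E[x^2] - E[x]^2
sigma = torch.sqrt(var + eps)
out = weight * (mu / sigma) + bias

ref = F.layer_norm(x, (hidden,), weight, bias, eps)
assert torch.allclose(out, ref, atol=1e-8)
```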
@@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function):
         bias_id: int,
         normalized_shape: int,
         eps: float,
-        input_parallel_mode: ParallelMode,
-        weight_parallel_mode: ParallelMode,
         output_parallel_mode: ParallelMode,
         input_x_weight_parallel_mode: ParallelMode,
     ) -> Tensor:
         ctx.weight_id = weight_id
         ctx.bias_id = bias_id
 
-        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        mu = input_ - mean
-        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        sigma = torch.sqrt(var + eps)
+        sum_ = torch.sum(input_, dim=-1, keepdim=True)
+        sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
+        mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
 
+        output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)
         ctx.save_for_backward(mu, sigma, weight)
-        z = mu / sigma
-        output = weight * z + bias
 
         ctx.normalized_shape = normalized_shape
-        ctx.input_parallel_mode = input_parallel_mode
-        ctx.weight_parallel_mode = weight_parallel_mode
         ctx.output_parallel_mode = output_parallel_mode
         ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
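The forward now issues a single `all_reduce` over `torch.stack((sum_, sqr_sum))` instead of two separate reductions (one for the mean, one for the variance). Because a sum reduction is elementwise, reducing the stacked pair and unpacking it gives the same result as reducing each statistic on its own; a small sketch that simulates the cross-rank reduction as plain summation (assuming the default sum op):

```python
import torch

torch.manual_seed(0)
ranks = 4
# Per-rank partial statistics over a sharded hidden dimension.
sums = [torch.randn(8, 1) for _ in range(ranks)]
sqr_sums = [torch.randn(8, 1).abs() for _ in range(ranks)]

# Two separate reductions.
mean_a = sum(sums)
sqr_mean_a = sum(sqr_sums)

# One reduction of the stacked pair, then unpack: same result.
stacked = sum(torch.stack((s, q)) for s, q in zip(sums, sqr_sums))
mean_b, sqr_mean_b = stacked

assert torch.allclose(mean_a, mean_b)
assert torch.allclose(sqr_mean_a, sqr_mean_b)
```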
@@ -324,23 +341,18 @@ class _Layernorm3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
-        with torch.no_grad():
         bias_grad, weight_grad = output_grad, output_grad * mu / sigma
         bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
         bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
         bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
         weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
         weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
-        dz = output_grad * weight
-        dvar = dz * mu * (-0.5) * sigma**(-3)
-        dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
-        dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
-        dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
-        input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+        dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
+        dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
+        input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
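On a single device (one shard per group, so the `all_reduce` of the stacked gradients is an identity) the `norm_backward`-based input gradient `dmu + (dmean + 2 * dvar * mu) / normalized_shape` matches autograd through `F.layer_norm`. A sketch re-implementing the helper locally to verify that:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 32
x = torch.randn(4, 8, n, dtype=torch.float64, requires_grad=True)
weight = torch.randn(n, dtype=torch.float64)
bias = torch.randn(n, dtype=torch.float64)
eps = 1e-5

out = F.layer_norm(x, (n,), weight, bias, eps)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Local re-implementation of norm_backward plus the input_grad assembly.
with torch.no_grad():
    mean = x.mean(dim=-1, keepdim=True)
    sigma = torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    mu = x - mean

    dz = grad_out * weight
    dmu = dz / sigma
    dvar = torch.sum(dz * mu * (-0.5) * sigma ** (-3), dim=-1, keepdim=True)
    dmean = torch.sum(-dmu, dim=-1, keepdim=True)
    input_grad = dmu + (dmean + 2 * dvar * mu) / n

assert torch.allclose(input_grad, x.grad, atol=1e-8)
```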
@@ -351,8 +363,6 @@ def layernorm_3d(
     bias: Tensor,
     normalized_shape: int,
     eps: float,
-    input_parallel_mode: ParallelMode,
-    weight_parallel_mode: ParallelMode,
     output_parallel_mode: ParallelMode,
     input_x_weight_parallel_mode: ParallelMode,
 ) -> Tensor:
@@ -368,9 +378,8 @@ def layernorm_3d(
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float): a value added to the denominator for numerical stability
-        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
-        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
         output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+        input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.
 
     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
@@ -384,8 +393,6 @@ def layernorm_3d(
         id(bias),
         normalized_shape,
         eps,
-        input_parallel_mode,
-        weight_parallel_mode,
         output_parallel_mode,
         input_x_weight_parallel_mode,
     )

@@ -5,6 +5,9 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed
@@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
-                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
-from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
+from ._operation import (
+    all_gather_tensor_3d,
+    classifier_3d,
+    layernorm_3d,
+    linear_3d,
+    reduce_scatter_tensor_3d,
+    split_batch_3d,
+    split_tensor_3d,
+    vocab_parallel_classifier_3d,
+)
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group
 
 
 @LAYERS.register_module
@@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer):
             self.bias,
             self.normalized_shape,
             self.variance_epsilon,
-            self.input_parallel_mode,
-            self.weight_parallel_mode,
             self.output_parallel_mode,
             self.input_x_weight_parallel_mode,
         )
@@ -900,7 +910,7 @@ class PatchEmbedding3D(ParallelLayer):
                                        weight_parallel_mode=self.weight_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
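Only the comment spacing changes here, but the reshape it describes is easy to misread: the conv output `(B, C, H, W)` is flattened into a token sequence `(B, H*W, C)`, and a learned class token is prepended. A minimal sketch with made-up sizes (the zero tensor stands in for the learned `cls_token` parameter):

```python
import torch

batch, embed_dim, grid = 2, 96, 14               # 14x14 patches, illustrative sizes
feat = torch.randn(batch, embed_dim, grid, grid)    # BCHW conv output

tokens = feat.flatten(2).transpose(1, 2)             # BCHW -> BNC
assert tokens.shape == (batch, grid * grid, embed_dim)

cls_token = torch.zeros(1, 1, embed_dim)             # stand-in for the learned parameter
tokens = torch.cat((cls_token.expand(batch, -1, -1), tokens), dim=1)
assert tokens.shape == (batch, grid * grid + 1, embed_dim)
```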