diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index b64488a12..e96abd87e 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): parallel_input = get_parallel_input() - if not parallel_input: + if not parallel_input and not gather_output: layer = Linear1D_Col(in_features, out_features, bias=bias, dtype=dtype, - gather_output=gather_output, skip_bias_add=skip_bias_add, weight_initializer=weight_initializer, bias_initializer=bias_initializer) diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index aeba5cc9d..885d06e6d 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -4,13 +4,15 @@ from typing import Optional, Tuple import torch -from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter) -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from ._utils import get_parallel_mode_from_env, push_async_grad + +from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc + +from ._utils import get_parallel_mode_from_env, push_async_grad class _Linear3D(torch.autograd.Function): @@ -44,18 +46,17 @@ class _Linear3D(torch.autograd.Function): @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors - with torch.no_grad(): - output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) - input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) - input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) - weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) - weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - input_op.wait() + input_op.wait() return input_grad, weight_grad, None, None, None, None @@ -129,25 +130,24 @@ class _Classifier3D(torch.autograd.Function): @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors - with torch.no_grad(): - weight_grad = torch.matmul( - output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) - weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) - if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): - weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - else: - weight_grad = None + weight_grad = torch.matmul( + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) + weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) + if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): + weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + else: + weight_grad = None - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) - else: - bias_grad = None + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None - input_grad = torch.matmul(output_grad, weight) + input_grad = torch.matmul(output_grad, weight) return input_grad, weight_grad, bias_grad, None, None, None, None, None @@ -224,25 +224,24 @@ class _VocabParallelClassifier3D(torch.autograd.Function): @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors - with torch.no_grad(): - output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) - input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) - input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) - weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) - weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) - else: - bias_grad = None + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None - input_op.wait() + input_op.wait() return input_grad, weight_grad, bias_grad, None, None, None, None, None @@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d( ) +@torch.jit.script +def norm_forward(x, mean, sqr_mean, weight, bias, eps): + mu = x - mean + var = sqr_mean - mean**2 + sigma = torch.sqrt(var + eps) + z = mu / sigma + output = weight * z + bias + + return output, mu, sigma + + +@torch.jit.script +def norm_backward(grad, mu, sigma, weight): + # dbias, dweight = grad, grad * mu / sigma + dz = grad * weight + dmu = dz / sigma + dvar = dz * mu * (-0.5) * sigma**(-3) + dmean = -dmu + dvar = torch.sum(dvar, -1, keepdim=True) + dmean = torch.sum(dmean, -1, keepdim=True) + + return dmu, dmean, dvar + + class _Layernorm3D(torch.autograd.Function): @staticmethod @@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function): bias_id: int, normalized_shape: int, eps: float, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode, input_x_weight_parallel_mode: ParallelMode, ) -> Tensor: ctx.weight_id = weight_id ctx.bias_id = bias_id - mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape - mu = input_ - mean - var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape - sigma = torch.sqrt(var + eps) + sum_ = torch.sum(input_, dim=-1, keepdim=True) + sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True) + mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape + + output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps) ctx.save_for_backward(mu, sigma, weight) - z = mu / sigma - output = weight * z + bias - ctx.normalized_shape = normalized_shape - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode ctx.output_parallel_mode = output_parallel_mode ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode @@ -324,23 +341,18 @@ class _Layernorm3D(torch.autograd.Function): @custom_bwd def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: mu, sigma, weight = ctx.saved_tensors - with torch.no_grad(): - bias_grad, weight_grad = output_grad, output_grad * mu / sigma - bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) - bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) - bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) - weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) - weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) + weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - dz = output_grad * weight - dvar = dz * mu * (-0.5) * sigma**(-3) - dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode) - dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape - dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode) - - input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape + dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight) + dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode) + input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None @@ -351,8 +363,6 @@ def layernorm_3d( bias: Tensor, normalized_shape: int, eps: float, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode, input_x_weight_parallel_mode: ParallelMode, ) -> Tensor: @@ -368,9 +378,8 @@ def layernorm_3d( If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps (float): a value added to the denominator for numerical stability - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -384,8 +393,6 @@ def layernorm_3d( id(bias), normalized_shape, eps, - input_parallel_mode, - weight_parallel_mode, output_parallel_mode, input_x_weight_parallel_mode, ) diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 6b3a7f4cc..0a1db6800 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -5,6 +5,9 @@ from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + from colossalai.communication import all_reduce, broadcast from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D from colossalai.context import ParallelMode, seed @@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict) +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn import Parameter from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d, - reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d) -from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook +from ._operation import ( + all_gather_tensor_3d, + classifier_3d, + layernorm_3d, + linear_3d, + reduce_scatter_tensor_3d, + split_batch_3d, + split_tensor_3d, + vocab_parallel_classifier_3d, +) +from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group @LAYERS.register_module @@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer): self.bias, self.normalized_shape, self.variance_epsilon, - self.input_parallel_mode, - self.weight_parallel_mode, self.output_parallel_mode, self.input_x_weight_parallel_mode, ) @@ -900,7 +910,7 @@ class PatchEmbedding3D(ParallelLayer): weight_parallel_mode=self.weight_parallel_mode) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1)