diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py
index 4c6150bd5..86861d302 100644
--- a/colossalai/nn/layer/colossalai_layer/normalization.py
+++ b/colossalai/nn/layer/colossalai_layer/normalization.py
@@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
 from ..parallel_2p5d import LayerNorm2p5D
 from ..parallel_3d import LayerNorm3D
 from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
 from ._utils import ColossalaiModule
 
-_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {
+    None: VanillaLayerNorm,
+    "1d": LayerNorm1D,
+    "2d": LayerNorm2D,
+    "2.5d": LayerNorm2p5D,
+    "3d": LayerNorm3D,
+}
 
 
 class LayerNorm(ColossalaiModule):
@@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
 
     Args:
         normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
-        eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
             norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 2daf875f8..9459e7139 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter
-from ..vanilla import VanillaPatchEmbedding
+from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
 
@@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
     r"""
     Layer Normalization for colossalai
 
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-        \times \ldots \times \text{normalized_shape}[-1]]`
-        If a single integer is used, it is treated as a singleton list, and this module will
-        normalize over the last dimension which is expected to be of that specific size.
-    :type normalized_shape: int
-    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
-    :type eps: float, optional
-    :param dtype: The dtype of parameters, defaults to None
-    :type dtype: torch.dtype, optional
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
-        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
         super().__init__(norm)
 
     def _load_from_state_dict(self, state_dict, prefix, *args):
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py
index 1ba7768ac..cec7cb8f7 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/nn/layer/parallel_2d/layers.py
@@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config
@@ -239,13 +240,17 @@ class LayerNorm2D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -294,7 +299,9 @@ class LayerNorm2D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
@@ -345,13 +352,17 @@ class LayerNorm2D(ParallelLayer):
 
         output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                               ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
-                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
-                           self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
         scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
+                               ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
+                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                               self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
 
         return output
 
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py
index dd188885f..d89150642 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/nn/layer/parallel_2p5d/layers.py
@@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config
@@ -259,13 +260,17 @@ class LayerNorm2p5D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attribute()
 
     def _set_tensor_parallel_attribute(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -314,7 +319,9 @@ class LayerNorm2p5D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
@@ -364,15 +371,18 @@ class LayerNorm2p5D(ParallelLayer):
         Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
 
         output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
-        bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
-                             self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
-                             self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                             self.tensor_parallel_size)
         scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
                               self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                               self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
+                                 self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
+                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                                 self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
 
         return output
 
diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py
index 01251535f..b4d6f734e 100644
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/nn/layer/parallel_3d/_operation.py
@@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
         mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@@ -201,8 +201,11 @@ class _Layernorm3D(torch.autograd.Function):
         ctx.save_for_backward(mu, sigma, weight)
 
         z = mu / sigma
-        output = weight * z + bias
+        output = weight * z
+        if bias is not None:
+            output = output + bias
 
+        ctx.use_bias = bias is not None
         ctx.normalized_shape = normalized_shape
         ctx.input_parallel_mode = input_parallel_mode
         ctx.weight_parallel_mode = weight_parallel_mode
@@ -215,12 +218,17 @@ class _Layernorm3D(torch.autograd.Function):
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
         with torch.no_grad():
-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            grads = torch.stack([bias_grad, weight_grad]).contiguous()
-            grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
-            grads = all_reduce(grads, ctx.weight_parallel_mode)
-            grads = all_reduce(grads, ctx.input_parallel_mode)
-            bias_grad, weight_grad = grads[0], grads[1]
+            weight_grad = output_grad * mu / sigma
+            if ctx.use_bias:
+                bias_grad = output_grad
+                weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
+            else:
+                bias_grad = None
+            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[int(ctx.use_bias):-1]))
+            weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
+            weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
+            if ctx.use_bias:
+                bias_grad, weight_grad = weight_grad[0], weight_grad[1]
 
             dz = output_grad * weight
             dvar = dz * mu * (-0.5) * sigma**(-3)
@@ -234,7 +242,7 @@ class _Layernorm3D(torch.autograd.Function):
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
 
 
-def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
     r"""3D parallel Layernorm.
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index f437c44e0..654d5d07f 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
         super().__init__()
 
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -51,18 +52,23 @@ class LayerNorm3D(ParallelLayer):
 
         self.weight = Parameter(
             torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
-        self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
-                                          dtype=dtype))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
+                                              device=get_current_device(), dtype=dtype))
+        else:
+            self.bias = None
         self.variance_epsilon = eps
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self) -> None:
         set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
 
     def reset_parameters(self) -> None:
-        init.zeros_()(self.bias)
         init.ones_()(self.weight)
+        if self.bias is not None:
+            init.zeros_()(self.bias)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -104,7 +110,9 @@ class LayerNorm3D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in output groups
         if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/nn/layer/vanilla/__init__.py
index 14af80027..172f28967 100644
--- a/colossalai/nn/layer/vanilla/__init__.py
+++ b/colossalai/nn/layer/vanilla/__init__.py
@@ -1,5 +1,7 @@
-from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
-    WrappedDropout, WrappedDropPath
+from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm,
+                     VanillaPatchEmbedding, WrappedDropout, WrappedDropPath)
 
-__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
-           'WrappedDropout', 'WrappedDropPath']
+__all__ = [
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier",
+    "DropPath", "WrappedDropout", "WrappedDropPath"
+]
diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py
index 0d5bcbaf8..dfc37af13 100644
--- a/colossalai/nn/layer/vanilla/layers.py
+++ b/colossalai/nn/layer/vanilla/layers.py
@@ -254,3 +254,37 @@ class VanillaClassifier(nn.Module):
 
     def forward(self, input_: Tensor) -> Tensor:
         return F.linear(input_, self.weight, self.bias)
+
+
+@LAYERS.register_module
+class VanillaLayerNorm(nn.Module):
+    r"""
+    Layer Normalization for colossalai
+
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        super().__init__()
+
+        self.normalized_shape = (normalized_shape,)
+        self.variance_epsilon = eps
+
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
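
Usage sketch (illustrative, not part of the patch): a minimal example of how the new `bias` switch on `VanillaLayerNorm` is expected to behave once this diff is applied. The import path and constructor signature come from the diff above; the tensor sizes are arbitrary, and the tensor-parallel variants (`LayerNorm1D`/`2D`/`2p5D`/`3D`) accept the same `bias=` argument but additionally require an initialized parallel context, so only the vanilla path is shown.

import torch
from colossalai.nn.layer.vanilla import VanillaLayerNorm

# build a layer norm without an affine bias; only `weight` is registered as a parameter
norm = VanillaLayerNorm(64, eps=1e-05, bias=False)
assert norm.bias is None
assert [name for name, _ in norm.named_parameters()] == ['weight']

# forward falls back to F.layer_norm(..., bias=None, ...), so the output keeps the input shape
x = torch.randn(2, 16, 64, device=norm.weight.device, dtype=norm.weight.dtype)
out = norm(x)
assert out.shape == x.shape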