[TP] allow layernorm without bias (#750)

pull/756/head
アマデウス 2022-04-14 11:43:56 +08:00 committed by GitHub
parent 3d7dc46d33
commit b8899e0905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 134 additions and 53 deletions

View File

@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
from ..parallel_2p5d import LayerNorm2p5D
from ..parallel_3d import LayerNorm3D
from ..utils import get_tensor_parallel_mode
from ..vanilla import VanillaLayerNorm
from ._utils import ColossalaiModule
_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_layernorm = {
None: VanillaLayerNorm,
"1d": LayerNorm1D,
"2d": LayerNorm2D,
"2.5d": LayerNorm2p5D,
"3d": LayerNorm3D,
}
class LayerNorm(ColossalaiModule):
@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel is None:
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())

View File

@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
r"""
Layer Normalization for colossalai
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):

View File

@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
super().__init__()
# layer norm config
@ -239,13 +240,17 @@ class LayerNorm2D(ParallelLayer):
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
if bias:
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
else:
self.bias = None
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -294,7 +299,9 @@ class LayerNorm2D(ParallelLayer):
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
@ -345,13 +352,17 @@ class LayerNorm2D(ParallelLayer):
output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL)
bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
if self.bias is not None:
bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
else:
output = torch.mul(scale, output)
return output

View File

@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
super().__init__()
# layer norm config
@ -259,13 +260,17 @@ class LayerNorm2p5D(ParallelLayer):
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
if bias:
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
else:
self.bias = None
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -314,7 +319,9 @@ class LayerNorm2p5D(ParallelLayer):
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
@ -364,15 +371,18 @@ class LayerNorm2p5D(ParallelLayer):
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
if self.bias is not None:
bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
else:
output = torch.mul(scale, output)
return output

View File

@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@ -201,8 +201,11 @@ class _Layernorm3D(torch.autograd.Function):
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z + bias
output = weight * z
if bias is not None:
output = output + bias
ctx.use_bias = bias is not None
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
@ -215,12 +218,17 @@ class _Layernorm3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
grads = torch.stack([bias_grad, weight_grad]).contiguous()
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
grads = all_reduce(grads, ctx.weight_parallel_mode)
grads = all_reduce(grads, ctx.input_parallel_mode)
bias_grad, weight_grad = grads[0], grads[1]
weight_grad = output_grad * mu / sigma
if ctx.use_bias:
bias_grad = output_grad
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
else:
bias_grad = None
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
if ctx.use_bias:
bias_grad, weight_grad = weight_grad[0], weight_grad[1]
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
@ -234,7 +242,7 @@ class _Layernorm3D(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
r"""3D parallel Layernorm.

View File

@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@ -51,18 +52,23 @@ class LayerNorm3D(ParallelLayer):
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
dtype=dtype))
if bias:
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.variance_epsilon = eps
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def reset_parameters(self) -> None:
init.zeros_()(self.bias)
init.ones_()(self.weight)
if self.bias is not None:
init.zeros_()(self.bias)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -104,7 +110,9 @@ class LayerNorm3D(ParallelLayer):
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \

View File

@ -1,5 +1,7 @@
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
WrappedDropout, WrappedDropPath
from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm,
VanillaPatchEmbedding, WrappedDropout, WrappedDropPath)
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
'WrappedDropout', 'WrappedDropPath']
__all__ = [
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier",
"DropPath", "WrappedDropout", "WrappedDropPath"
]

View File

@ -254,3 +254,37 @@ class VanillaClassifier(nn.Module):
def forward(self, input_: Tensor) -> Tensor:
return F.linear(input_, self.weight, self.bias)
@LAYERS.register_module
class VanillaLayerNorm(nn.Module):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
super().__init__()
self.normalized_shape = (normalized_shape,)
self.variance_epsilon = eps
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
if bias:
self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
else:
self.bias = None
def forward(self, x: Tensor) -> Tensor:
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)