mirror of https://github.com/hpcaitech/ColossalAI
[TP] allow layernorm without bias (#750)
parent
3d7dc46d33
commit
b8899e0905
|
@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
|
|||
from ..parallel_2p5d import LayerNorm2p5D
|
||||
from ..parallel_3d import LayerNorm3D
|
||||
from ..utils import get_tensor_parallel_mode
|
||||
from ..vanilla import VanillaLayerNorm
|
||||
from ._utils import ColossalaiModule
|
||||
|
||||
_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
|
||||
_parallel_layernorm = {
|
||||
None: VanillaLayerNorm,
|
||||
"1d": LayerNorm1D,
|
||||
"2d": LayerNorm2D,
|
||||
"2.5d": LayerNorm2p5D,
|
||||
"3d": LayerNorm3D,
|
||||
}
|
||||
|
||||
|
||||
class LayerNorm(ColossalaiModule):
|
||||
|
@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
|
|||
|
||||
Args:
|
||||
normalized_shape (int): input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
|
||||
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||
bias (bool, optional): Whether to add a bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel is None:
|
||||
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
|
||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from ..vanilla import VanillaPatchEmbedding
|
||||
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..colossalai_layer._utils import ColossalaiModule
|
||||
|
@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
|
|||
r"""
|
||||
Layer Normalization for colossalai
|
||||
|
||||
:param normalized_shape: input shape from an expected input
|
||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
:type normalized_shape: int
|
||||
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
|
||||
:type eps: float, optional
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
Args:
|
||||
normalized_shape (int): input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||
bias (bool, optional): Whether to add a bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
|
||||
norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
|
||||
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
|
||||
super().__init__(norm)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
|
|
|
@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
|
|||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||
bias (bool, optional): Whether to add a bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
# layer norm config
|
||||
|
@ -239,13 +240,17 @@ class LayerNorm2D(ParallelLayer):
|
|||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
|
@ -294,7 +299,9 @@ class LayerNorm2D(ParallelLayer):
|
|||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
|
@ -345,13 +352,17 @@ class LayerNorm2D(ParallelLayer):
|
|||
|
||||
output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
|
||||
ParallelMode.PARALLEL_2D_COL)
|
||||
bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
if self.bias is not None:
|
||||
bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
else:
|
||||
output = torch.mul(scale, output)
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||
bias (bool, optional): Whether to add a bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
# layer norm config
|
||||
|
@ -259,13 +260,17 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
|
@ -314,7 +319,9 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
|
@ -364,15 +371,18 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
|
||||
|
||||
output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
|
||||
bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
if self.bias is not None:
|
||||
bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
else:
|
||||
output = torch.mul(scale, output)
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
||||
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
|
||||
|
@ -201,8 +201,11 @@ class _Layernorm3D(torch.autograd.Function):
|
|||
ctx.save_for_backward(mu, sigma, weight)
|
||||
|
||||
z = mu / sigma
|
||||
output = weight * z + bias
|
||||
output = weight * z
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
|
@ -215,12 +218,17 @@ class _Layernorm3D(torch.autograd.Function):
|
|||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
mu, sigma, weight = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
|
||||
grads = torch.stack([bias_grad, weight_grad]).contiguous()
|
||||
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
|
||||
grads = all_reduce(grads, ctx.weight_parallel_mode)
|
||||
grads = all_reduce(grads, ctx.input_parallel_mode)
|
||||
bias_grad, weight_grad = grads[0], grads[1]
|
||||
weight_grad = output_grad * mu / sigma
|
||||
if ctx.use_bias:
|
||||
bias_grad = output_grad
|
||||
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
|
||||
else:
|
||||
bias_grad = None
|
||||
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
|
||||
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
|
||||
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
|
||||
if ctx.use_bias:
|
||||
bias_grad, weight_grad = weight_grad[0], weight_grad[1]
|
||||
|
||||
dz = output_grad * weight
|
||||
dvar = dz * mu * (-0.5) * sigma**(-3)
|
||||
|
@ -234,7 +242,7 @@ class _Layernorm3D(torch.autograd.Function):
|
|||
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
||||
|
||||
|
||||
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
||||
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
r"""3D parallel Layernorm.
|
||||
|
|
|
@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
|
|||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
|
||||
bias (bool, optional): Whether to add a bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
|
||||
|
||||
super().__init__()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
|
@ -51,18 +52,23 @@ class LayerNorm3D(ParallelLayer):
|
|||
|
||||
self.weight = Parameter(
|
||||
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
|
||||
dtype=dtype))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
|
||||
device=get_current_device(), dtype=dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
self.variance_epsilon = eps
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self) -> None:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.zeros_()(self.bias)
|
||||
init.ones_()(self.weight)
|
||||
if self.bias is not None:
|
||||
init.zeros_()(self.bias)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
|
@ -104,7 +110,9 @@ class LayerNorm3D(ParallelLayer):
|
|||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
|
||||
# gather in output groups
|
||||
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
|
||||
WrappedDropout, WrappedDropPath
|
||||
from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm,
|
||||
VanillaPatchEmbedding, WrappedDropout, WrappedDropPath)
|
||||
|
||||
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
|
||||
'WrappedDropout', 'WrappedDropPath']
|
||||
__all__ = [
|
||||
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier",
|
||||
"DropPath", "WrappedDropout", "WrappedDropPath"
|
||||
]
|
||||
|
|
|
@ -254,3 +254,37 @@ class VanillaClassifier(nn.Module):
|
|||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return F.linear(input_, self.weight, self.bias)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaLayerNorm(nn.Module):
|
||||
r"""
|
||||
Layer Normalization for colossalai
|
||||
|
||||
Args:
|
||||
normalized_shape (int): input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||
bias (bool, optional): Whether to add a bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
self.normalized_shape = (normalized_shape,)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
|
||||
|
|
Loading…
Reference in New Issue