mirror of https://github.com/hpcaitech/ColossalAI
[TP] allow layernorm without bias (#750)
parent 3d7dc46d33
commit b8899e0905
@@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
 from ..parallel_2p5d import LayerNorm2p5D
 from ..parallel_3d import LayerNorm3D
 from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
 from ._utils import ColossalaiModule
 
-_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {
+    None: VanillaLayerNorm,
+    "1d": LayerNorm1D,
+    "2d": LayerNorm2D,
+    "2.5d": LayerNorm2p5D,
+    "3d": LayerNorm3D,
+}
 
 
 class LayerNorm(ColossalaiModule):
@@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
 
     Args:
         normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
-        eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
             norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
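With the hunk above, the unified LayerNorm wrapper exposes the new bias flag and dispatches on get_tensor_parallel_mode(). A minimal usage sketch, not part of the diff: it assumes colossalai has already been launched (with or without a tensor parallel mode) and that the wrapper is re-exported as colossalai.nn.LayerNorm; the hidden size and dtype are arbitrary example values.

import torch
import colossalai.nn as col_nn

# Build a layer norm over a hidden size of 768 with no learnable bias.
# The wrapper picks VanillaLayerNorm when tensor parallelism is off, or the
# 1D/2D/2.5D/3D variant matching the configured parallel mode otherwise.
norm = col_nn.LayerNorm(768, eps=1e-05, bias=False, dtype=torch.float16)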
@@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter
-from ..vanilla import VanillaPatchEmbedding
+from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
@@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
     r"""
     Layer Normalization for colossalai
 
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-        \times \ldots \times \text{normalized_shape}[-1]]`
-        If a single integer is used, it is treated as a singleton list, and this module will
-        normalize over the last dimension which is expected to be of that specific size.
-    :type normalized_shape: int
-    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
-    :type eps: float, optional
-    :param dtype: The dtype of parameters, defaults to None
-    :type dtype: torch.dtype, optional
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
-        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
         super().__init__(norm)
 
     def _load_from_state_dict(self, state_dict, prefix, *args):
@@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config
@@ -239,13 +240,17 @@ class LayerNorm2D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -294,7 +299,9 @@ class LayerNorm2D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
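The checkpoint hook above now stores the bias shard only when it exists, so a bias-free layer saves a weight-only state dict instead of a placeholder. A small sketch of the resulting convention, shown with the plain VanillaLayerNorm added later in this commit because it needs no process groups (it does assume a visible GPU, since the parameters are allocated on get_current_device()):

from colossalai.nn.layer.vanilla import VanillaLayerNorm

with_bias = VanillaLayerNorm(64, bias=True)
without_bias = VanillaLayerNorm(64, bias=False)

print(sorted(with_bias.state_dict().keys()))      # ['bias', 'weight']
print(sorted(without_bias.state_dict().keys()))   # ['weight'] -- no 'bias' entry is written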
@@ -345,13 +352,17 @@ class LayerNorm2D(ParallelLayer):
 
         output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                               ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
-                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
-                           self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
         scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
+                               ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
+                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                               self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
         return output
 
 
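In the forward pass above, torch.addcmul(bias, scale, output) computes bias + scale * output, so when the bias is disabled the affine step collapses to a plain elementwise multiply. A quick check of that identity on ordinary tensors, outside any parallel context (shapes are arbitrary example values):

import torch

scale = torch.randn(8)
output = torch.randn(4, 8)
zero_bias = torch.zeros(8)

biased = torch.addcmul(zero_bias, scale, output)   # bias + scale * output
unbiased = torch.mul(scale, output)                # scale * output

print(torch.allclose(biased, unbiased))            # True: zero bias and no bias coincide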
@@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config
@@ -259,13 +260,17 @@ class LayerNorm2p5D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attribute()
 
     def _set_tensor_parallel_attribute(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -314,7 +319,9 @@ class LayerNorm2p5D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
@@ -364,15 +371,18 @@ class LayerNorm2p5D(ParallelLayer):
         Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
 
         output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
-        bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
-                             self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
-                             self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                             self.tensor_parallel_size)
         scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
                               self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                               self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
+                                 self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
+                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                                 self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
         return output
 
 
@@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
 
     @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
         mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@@ -201,8 +201,11 @@ class _Layernorm3D(torch.autograd.Function):
         ctx.save_for_backward(mu, sigma, weight)
 
         z = mu / sigma
-        output = weight * z + bias
+        output = weight * z
+        if bias is not None:
+            output = output + bias
 
+        ctx.use_bias = bias is not None
         ctx.normalized_shape = normalized_shape
         ctx.input_parallel_mode = input_parallel_mode
         ctx.weight_parallel_mode = weight_parallel_mode
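The forward pass above obtains the row mean by all-reducing partition-local sums over the output-parallel group and dividing by the full normalized_shape. A single-process sketch of why that equals the ordinary mean, with torch.chunk standing in for the shards held by the ranks and a plain Python sum standing in for all_reduce:

import torch

normalized_shape = 12
x = torch.randn(4, normalized_shape)
shards = torch.chunk(x, chunks=3, dim=-1)                       # each rank's slice of the hidden dim

partial = [s.sum(dim=-1, keepdim=True) for s in shards]         # torch.sum(input_, dim=-1, keepdim=True) per rank
mean = sum(partial) / normalized_shape                          # all_reduce(...) / normalized_shape

print(torch.allclose(mean, x.mean(dim=-1, keepdim=True)))       # True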
@@ -215,12 +218,17 @@ class _Layernorm3D(torch.autograd.Function):
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
         with torch.no_grad():
-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            grads = torch.stack([bias_grad, weight_grad]).contiguous()
-            grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
-            grads = all_reduce(grads, ctx.weight_parallel_mode)
-            grads = all_reduce(grads, ctx.input_parallel_mode)
-            bias_grad, weight_grad = grads[0], grads[1]
+            weight_grad = output_grad * mu / sigma
+            if ctx.use_bias:
+                bias_grad = output_grad
+                weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
+            else:
+                bias_grad = None
+            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
+            weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
+            weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
+            if ctx.use_bias:
+                bias_grad, weight_grad = weight_grad[0], weight_grad[1]
 
             dz = output_grad * weight
             dvar = dz * mu * (-0.5) * sigma**(-3)
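The backward above keeps the original trick when a bias exists: bias_grad and weight_grad are stacked so the batch-dimension reduction and the two all_reduce calls cover both gradients at once, and the pair is split again afterwards; without a bias, only the weight gradient goes through that pipeline. A sketch of the stacking bookkeeping on plain tensors (the shapes are arbitrary example values, and the all_reduce calls are omitted since they do not change shapes):

import torch

output_grad = torch.randn(4, 16, 8)          # e.g. (batch, sequence, hidden shard)
mu = torch.randn(4, 16, 8)
sigma = torch.rand(4, 16, 8) + 0.5

weight_grad = output_grad * mu / sigma
bias_grad = output_grad
grads = torch.stack([bias_grad, weight_grad]).contiguous()           # (2, 4, 16, 8)
grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))   # (2, 8): batch dims folded away
# ...in the real code the (2, hidden) tensor is all-reduced over the weight and
# input parallel groups here, before the pair is split back apart...
bias_grad, weight_grad = grads[0], grads[1]

print(bias_grad.shape, weight_grad.shape)    # torch.Size([8]) torch.Size([8])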
@@ -234,7 +242,7 @@ class _Layernorm3D(torch.autograd.Function):
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
 
 
-def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
     r"""3D parallel Layernorm.
@@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
 
         super().__init__()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -51,18 +52,23 @@ class LayerNorm3D(ParallelLayer):
 
         self.weight = Parameter(
             torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
-        self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
-                                          dtype=dtype))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
+                                              device=get_current_device(), dtype=dtype))
+        else:
+            self.bias = None
         self.variance_epsilon = eps
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self) -> None:
         set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
 
     def reset_parameters(self) -> None:
-        init.zeros_()(self.bias)
         init.ones_()(self.weight)
+        if self.bias is not None:
+            init.zeros_()(self.bias)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -104,7 +110,9 @@ class LayerNorm3D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
 
         # gather in output groups
         if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
@@ -1,5 +1,7 @@
-from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
-    WrappedDropout, WrappedDropPath
+from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm,
+                     VanillaPatchEmbedding, WrappedDropout, WrappedDropPath)
 
-__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
-           'WrappedDropout', 'WrappedDropPath']
+__all__ = [
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier",
+    "DropPath", "WrappedDropout", "WrappedDropPath"
+]
@@ -254,3 +254,37 @@ class VanillaClassifier(nn.Module):
 
     def forward(self, input_: Tensor) -> Tensor:
         return F.linear(input_, self.weight, self.bias)
+
+
+@LAYERS.register_module
+class VanillaLayerNorm(nn.Module):
+    r"""
+    Layer Normalization for colossalai
+
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        super().__init__()
+
+        self.normalized_shape = (normalized_shape,)
+        self.variance_epsilon = eps
+
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
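The new VanillaLayerNorm mirrors torch.nn.LayerNorm but can drop the bias entirely, which is what LayerNorm1D and the non-parallel wrapper now build on. A minimal usage sketch, assuming a CUDA device is available since the parameters are placed on get_current_device():

import torch
from colossalai.nn.layer.vanilla import VanillaLayerNorm

norm = VanillaLayerNorm(256, eps=1e-05, bias=False, dtype=torch.float32)

x = torch.randn(2, 10, 256, device=norm.weight.device)
y = norm(x)                                        # F.layer_norm with bias=None

print(norm.bias)                                   # None: no bias parameter was created
print(sum(p.numel() for p in norm.parameters()))   # 256: weight only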