#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from collections import OrderedDict

from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule
from colossalai.utils.checkpointing import broadcast_state_dict

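# NOTE: apex's FastLayerNorm is an optional fused kernel that only covers the
# hidden sizes listed in `LayerNorm1D._fast_ln_supported_sizes`; the import is
# wrapped in try/except so this module still works when apex is not installed.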
Fast_LN = None
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm

    Fast_LN = FastLayerNorm
except ImportError:
    pass


class LayerNorm1D(ColossalaiModule):
    r"""
    Layer Normalization for colossalai

    Args:
        normalized_shape (int): input shape from an expected input of size.
            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
            \times \ldots \times \text{normalized_shape}[-1]]`
            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
        bias (bool, optional): Whether to add a bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
    """

    _fast_ln_supported_sizes = [
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
        24576, 25600, 30720, 32768, 40960, 49152, 65536
    ]

    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
            # prefer apex's fast fused kernel when it is available and supports this size
            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
        else:
            norm = None
            try:
                # fall back to apex's FusedLayerNorm if apex is installed
                from apex.normalization import FusedLayerNorm
                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
            except ImportError:
                # otherwise use the ColossalAI LayerNorm kernel
                norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
        super().__init__(norm)

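    # The checkpoint helpers below follow the ColossalAI convention for replicated
    # parameters: only tensor-parallel rank 0 reads/writes the weight and bias, and
    # loaded values are broadcast to the rest of the 1D parallel group.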
    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias

        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            super()._save_to_state_dict(destination, prefix, keep_vars)
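
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes the global parallel context has already been initialized with 1D
# tensor parallelism (e.g. via `colossalai.launch`); the layer then acts as a
# drop-in replacement for `torch.nn.LayerNorm`:
#
#     import torch
#
#     norm = LayerNorm1D(1024, eps=1e-5, dtype=torch.float16).cuda()
#     x = torch.randn(8, 512, 1024, dtype=torch.float16, device='cuda')
#     y = norm(x)    # output has the same shape as `x`
# ---------------------------------------------------------------------------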