mirror of https://github.com/hpcaitech/ColossalAI
302 lines
11 KiB
Python
302 lines
11 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
|
|
import torch.nn as nn
|
|
|
|
from colossalai.lazy import LazyInitContext
|
|
|
|
from ._operation import hook_paramter_in_backward
|
|
from .utils import SeqParallelUtils
|
|
|
|
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
|
|
|
# Probe for apex's optional FastLayerNorm kernel; the flag records whether the
# import succeeded so later code can pick an implementation.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
except ImportError:
    EnableFastLayerNorm = False
else:
    EnableFastLayerNorm = True
|
|
|
|
# The hooked apex wrappers only exist when apex itself can be imported; otherwise
# a warning is emitted and neither class is defined.
try:
    from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
except ImportError:
    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
else:

    class FusedLayerNormWithHook(ApexFusedLayerNorm):
        """Apex fused LayerNorm whose output is passed through ``hook_paramter_in_backward``."""

        def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            """Run the apex kernel, then hand the result plus weight and bias to the hook helper."""
            normalized = super().forward(input)
            return hook_paramter_in_backward(normalized, self.weight, self.bias)

    class FusedRMSNormWithHook(ApexFusedRMSNorm):
        """Apex fused RMSNorm whose output is passed through ``hook_paramter_in_backward``."""

        def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            """Run the apex kernel, then hand the result plus weight to the hook helper."""
            normalized = super().forward(input)
            return hook_paramter_in_backward(normalized, self.weight)
|
|
|
|
# Hidden sizes for which FusedLayerNorm.from_native_module considers the apex
# FastLayerNorm kernel.
FAST_LAYERNORM_SUPPORTED_SIZE = [
    1024, 1536, 2048, 2304, 3072, 3840,
    4096, 5120, 6144, 8192, 10240, 12288,
    12800, 15360, 16384, 18432, 20480, 24576,
    25600, 30720, 32768, 40960, 49152, 65536,
]  # fmt: skip
|
|
|
|
if EnableFastLayerNorm:

    class FastLayerNormWithHook(FastLayerNorm):
        """Apex FastLayerNorm whose output is passed through ``hook_paramter_in_backward``."""

        def __init__(self, hidden_size, eps=1e-5):
            super().__init__(hidden_size, eps)

        def forward(self, input):
            """Run the apex kernel, then hand the result plus weight and bias to the hook helper."""
            normalized = super().forward(input)
            return hook_paramter_in_backward(normalized, self.weight, self.bias)
|
|
|
|
|
|
class BaseLayerNorm(ABC):
    """Abstract interface shared by the layer-normalization wrappers in this module."""

    # ``@staticmethod`` is stacked on top of ``@abstractmethod`` so the abstract
    # declaration matches the concrete overrides, which are all static methods
    # taking ``module`` as their first argument.
    @staticmethod
    @abstractmethod
    def from_native_module(module: nn.Module, sp_partial_derived: bool = False) -> nn.Module:
        """
        Convert a native PyTorch layer normalization module to a specific layer normalization module,
        and optionally mark parameters for gradient aggregation.

        Args:
            module (nn.Module): The native PyTorch layer normalization module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: The specific layer normalization module.

        Raises:
            AssertionError: If the provided module is not an instance of the supported layer normalization type.
        """
|
|
|
|
|
|
class RMSNorm(BaseLayerNorm):
    r"""
    This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        # This class is a namespace for ``from_native_module`` only; instantiating it is an error.
        raise NotImplementedError(
            "RMSNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
        )

    @staticmethod
    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        """
        Convert a native RMSNorm module to colossalai layer norm module,
        and optionally mark parameters for gradient aggregation.

        The module is materialized through ``LazyInitContext`` and returned as-is;
        no new module is constructed.

        Args:
            module (nn.Module): The native RMSNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: The RMSNorm module.
        """

        LazyInitContext.materialize(module)

        if sp_partial_derived:
            # Since gradients are computed using only a subset of the data,
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)

        return module
|
|
|
|
|
|
class LayerNorm(BaseLayerNorm):
    r"""
    Namespace class wrapping torch.nn.LayerNorm. It cannot be instantiated; use the
    from_native_module interface instead.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "LayerNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
        )

    @staticmethod
    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Materialize a native pytorch layer norm module and return it, optionally marking
        its weight and bias for gradient aggregation.

        Args:
            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: The LayerNorm module (the same object that was passed in).

        Raises:
            AssertionError: If the provided module is not an instance of nn.LayerNorm.
        """
        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

        LazyInitContext.materialize(module)

        if not sp_partial_derived:
            return module

        # Gradients are computed from only a subset of the data under sequence parallelism,
        # so they must be aggregated during backpropagation. Annotate the parameters up
        # front to signal that aggregation is required.
        for param in (module.weight, module.bias):
            SeqParallelUtils.marked_as_sp_partial_derived_param(param)

        return module
|
|
|
|
|
|
class FusedLayerNorm(BaseLayerNorm):
    r"""
    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        # This class is a namespace for ``from_native_module`` only; instantiating it is an error.
        raise NotImplementedError(
            "FusedLayerNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
        )

    @staticmethod
    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
        and optionally marking parameters for gradient aggregation.

        The apex FastLayerNorm kernel is chosen when it was importable, the module is
        elementwise-affine, and it normalizes over a single dimension whose size is
        listed in ``FAST_LAYERNORM_SUPPORTED_SIZE``. Otherwise the regular apex fused
        layer norm is used. The new module shares the weight and bias of ``module``.

        Args:
            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
        """

        LazyInitContext.materialize(module)
        # get the attributes of the module
        normalized_shape = module.normalized_shape
        eps = module.eps
        elementwise_affine = module.elementwise_affine
        dtype = module.weight.dtype
        device = module.weight.device

        # nn.LayerNorm stores ``normalized_shape`` as a tuple, so testing the tuple itself
        # against a list of ints can never match. Extract the single hidden size instead.
        if isinstance(normalized_shape, int):
            hidden_size = normalized_shape
        elif len(normalized_shape) == 1:
            hidden_size = normalized_shape[0]
        else:
            hidden_size = None

        # pick the suitable layernorm implementation
        use_fast_ln = (
            EnableFastLayerNorm and elementwise_affine and hidden_size in FAST_LAYERNORM_SUPPORTED_SIZE
        )

        if use_fast_ln:
            # FastLayerNormWithHook only accepts (hidden_size, eps); it has no
            # ``elementwise_affine`` switch and always uses weight and bias.
            layernorm = FastLayerNormWithHook(hidden_size, eps=eps)
        else:
            # fall back to the normal fused layernorm if the fast kernel is not built
            # or does not cover this shape.
            # NOTE: FusedLayerNormWithHook is only defined when the apex import at the
            # top of this file succeeded.
            layernorm = FusedLayerNormWithHook(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

        layernorm = layernorm.to(dtype).to(device)

        # share the parameters of the native module
        layernorm.weight = module.weight
        layernorm.bias = module.bias

        if sp_partial_derived:
            # Since gradients are computed using only a subset of the data,
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)

        return layernorm
|
|
|
|
|
|
class FusedRMSNorm(BaseLayerNorm):
    """
    This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        # This class is a namespace for ``from_native_module`` only; instantiating it is an error.
        raise NotImplementedError(
            "FusedRMSNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
        )

    @staticmethod
    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
        Convert a native RMSNorm module to FusedRMSNorm module provided by apex,
        and optionally marking parameters for gradient aggregation.

        Modules whose class is named ``LlamaRMSNorm`` or ``MistralRMSNorm`` are read via
        ``weight.shape[0]`` and ``variance_epsilon``. Any other module must expose
        ``normalized_shape``, ``eps`` and ``elementwise_affine``. The new module shares
        the weight of ``module``.

        Args:
            module (nn.Module): The native RMSNorm module to be converted.
            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

        Returns:
            nn.Module: FusedRMSNorm module.

        Raises:
            ImportError: If apex is not installed, so the fused RMSNorm wrapper is unavailable.
        """
        # FusedRMSNormWithHook is only defined when the apex import at the top of this
        # file succeeded; turn the resulting NameError into the intended ImportError.
        try:
            fused_rmsnorm_cls = FusedRMSNormWithHook
        except NameError as err:
            raise ImportError(
                "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
            ) from err

        LazyInitContext.materialize(module)
        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
            normalized_shape = module.weight.shape[0]
            eps = module.variance_epsilon
            elementwise_affine = True
        else:
            # get the attributes of the module
            normalized_shape = module.normalized_shape
            eps = module.eps
            elementwise_affine = module.elementwise_affine

        rmsnorm = fused_rmsnorm_cls(
            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
        )

        # share the parameter of the native module
        rmsnorm.weight = module.weight

        if sp_partial_derived:
            # Since gradients are computed using only a subset of the data,
            # aggregation of these gradients is necessary during backpropagation.
            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)

        return rmsnorm
|