Remove CohereLayerNorm and use existing layernorm

pull/5818/head
GuangyaoZhang 5 months ago
parent c9025ebd7c
commit 8c3f524660

@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d from .loss import cross_entropy_1d
from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@ -23,8 +23,6 @@ __all__ = [
"RMSNorm", "RMSNorm",
"FusedLayerNorm", "FusedLayerNorm",
"FusedRMSNorm", "FusedRMSNorm",
"CohereLayerNorm",
"FusedCohereLayerNorm",
"FusedLinear1D_Col", "FusedLinear1D_Col",
"ParallelModule", "ParallelModule",
"PaddingEmbedding", "PaddingEmbedding",

@ -4,7 +4,6 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import torch.nn as nn import torch.nn as nn
from transformers.models.cohere.modeling_cohere import CohereLayerNorm
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
@ -141,32 +140,29 @@ class RMSNorm(BaseLayerNorm):
class LayerNorm(BaseLayerNorm): class LayerNorm(BaseLayerNorm):
r""" r"""
This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface. This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
""" """
def __init__(self) -> None: def __init__(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"LayerNorm is not implemented as a physical class. " "LayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module." "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
) )
@staticmethod @staticmethod
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r""" r"""
Convert a native pytorch layer norm module to colossalai layer norm module, Convert a native LayerNorm module to colossalai layer norm module,
and optionally marking parameters for gradient aggregation. and optionally marking parameters for gradient aggregation.
Args: Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. module (nn.Module): The native LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns: Returns:
nn.Module: The LayerNorm module. nn.Module: The colossalai LayerNorm module.
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
""" """
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
@ -175,6 +171,7 @@ class LayerNorm(BaseLayerNorm):
# aggregation of these gradients is necessary during backpropagation. # aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation. # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight) SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
if module.bias is not None:
SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias) SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
return module return module
@ -188,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
def __init__(self) -> None: def __init__(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. " "FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex." "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex."
) )
@staticmethod @staticmethod
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r""" r"""
Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex, Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
and optionally marking parameters for gradient aggregation. and optionally marking parameters for gradient aggregation.
Args: Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. module (nn.Module): The native LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns: Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm]. nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
""" """
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
# get the attributes of the module # get the attributes of the module
normalized_shape = module.normalized_shape normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
eps = module.eps eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = module.elementwise_affine elementwise_affine = getattr(module, "elementwise_affine", True)
dtype = module.weight.dtype dtype = module.weight.dtype
device = module.weight.device device = module.weight.device
@ -230,7 +225,7 @@ class FusedLayerNorm(BaseLayerNorm):
ApexFusedLayerNorm = FusedLayerNormWithHook ApexFusedLayerNorm = FusedLayerNormWithHook
except NameError: except NameError:
warnings.warn( warnings.warn(
"Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead." "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
) )
return module return module
@ -238,6 +233,7 @@ class FusedLayerNorm(BaseLayerNorm):
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
) )
layernorm.weight = module.weight layernorm.weight = module.weight
if module.bias is not None:
layernorm.bias = module.bias layernorm.bias = module.bias
if sp_partial_derived: if sp_partial_derived:
@ -250,114 +246,6 @@ class FusedLayerNorm(BaseLayerNorm):
return layernorm return layernorm
class CohereLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"CohereLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module."
)
@staticmethod
def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a CohereLayerNorm module to colossalai layer norm module,
and optionally marking parameters for gradient aggregation.
Args:
module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The LayerNorm module.
Raises:
AssertionError: If the provided module is not an instance of CohereLayerNorm
"""
LazyInitContext.materialize(module)
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
return module
class FusedCohereLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedCohereLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex."
)
@staticmethod
def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a CohereLayerNorm module to FusedLayerNorm module provided by apex,
and optionally marking parameters for gradient aggregation.
Args:
module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm.
"""
LazyInitContext.materialize(module)
# get the attributes of the module
normalized_shape = module.weight.size(0)
eps = module.variance_epsilon
elementwise_affine = True
dtype = module.weight.dtype
device = module.weight.device
# pick the suitable layernorm implementation
use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
if use_fast_ln:
if EnableFastLayerNorm:
ApexFusedLayerNorm = FastLayerNormWithHook
else:
# fall back to the normal fused layernorm is not built
ApexFusedLayerNorm = FusedLayerNormWithHook
else:
try:
ApexFusedLayerNorm = FusedLayerNormWithHook
except NameError:
warnings.warn(
"Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
)
return module
layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
)
layernorm.weight = module.weight
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
return layernorm
class FusedRMSNorm(BaseLayerNorm): class FusedRMSNorm(BaseLayerNorm):
""" """
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.

@ -7,8 +7,8 @@ from torch import Tensor
from torch.nn import Module from torch.nn import Module
from colossalai.shardformer.layer import ( from colossalai.shardformer.layer import (
CohereLayerNorm, FusedLayerNorm,
FusedCohereLayerNorm, LayerNorm,
Linear1D_Col, Linear1D_Col,
Linear1D_Row, Linear1D_Row,
PaddingEmbedding, PaddingEmbedding,
@ -64,9 +64,9 @@ class CommandPolicy(Policy):
embedding_cls = PaddingEmbedding embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
norm_cls = FusedCohereLayerNorm norm_cls = FusedLayerNorm
else: else:
norm_cls = CohereLayerNorm norm_cls = LayerNorm
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False

Loading…
Cancel
Save