mirror of https://github.com/hpcaitech/ColossalAI
Remove CohereLayerNorm and use existing layernorm
parent fe2e74c03a
commit 7a2b08646f
@@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d
-from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

@@ -23,8 +23,6 @@ __all__ = [
     "RMSNorm",
     "FusedLayerNorm",
     "FusedRMSNorm",
-    "CohereLayerNorm",
-    "FusedCohereLayerNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
     "PaddingEmbedding",

@@ -4,7 +4,6 @@ import warnings
 from abc import ABC, abstractmethod

 import torch.nn as nn
-from transformers.models.cohere.modeling_cohere import CohereLayerNorm

 from colossalai.lazy import LazyInitContext

@@ -141,32 +140,29 @@ class RMSNorm(BaseLayerNorm):

 class LayerNorm(BaseLayerNorm):
     r"""
-    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
     """

     def __init__(self) -> None:
         raise NotImplementedError(
             "LayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
         )

     @staticmethod
-    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module,
+        Convert a native LayerNorm module to colossalai layer norm module,
         and optionally marking parameters for gradient aggregation.

         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

         Returns:
-            nn.Module: The LayerNorm module.
+            nn.Module: The colossalai LayerNorm module.

-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

         LazyInitContext.materialize(module)

@@ -175,7 +171,8 @@ class LayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            if module.bias is not None:
+                SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

         return module

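Why the guard matters: with the isinstance assertion gone, this wrapper can now receive norm modules whose bias is absent (for example nn.LayerNorm(..., bias=False) on recent PyTorch, or a Cohere-style norm that only carries a weight). A minimal sketch of the pattern, where mark_param is a hypothetical stand-in for SeqParallelUtils.marked_as_sp_partial_derived_param:

import torch.nn as nn

def mark_param(p: nn.Parameter) -> None:
    # Hypothetical stand-in: assume the real helper tags the parameter in place
    # so the gradient handler knows to aggregate it across the sequence-parallel group.
    p.partial_derived = True  # attribute name is illustrative only

ln = nn.LayerNorm(64, bias=False)   # bias is None here (PyTorch >= 2.1)
mark_param(ln.weight)
if ln.bias is not None:             # mirrors the new check: skip marking when there is no bias
    mark_param(ln.bias)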
@@ -188,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
+            "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex."
         )

     @staticmethod
     def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
         and optionally marking parameters for gradient aggregation.

         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.

         Returns:
             nn.Module: Union[FastLayerNorm, FusedLayerNorm].

-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """

         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
         dtype = module.weight.dtype
         device = module.weight.device

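A minimal sketch of how these fallbacks resolve for the two module shapes the fused wrapper now has to serve. CohereStyleNorm below is a hypothetical stand-in for a norm that, like the removed Cohere-specific path, exposes only weight and variance_epsilon (no normalized_shape, eps, or elementwise_affine):

import torch
import torch.nn as nn

class CohereStyleNorm(nn.Module):
    # Hypothetical stand-in: weight + variance_epsilon only.
    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

def extract_norm_attrs(module: nn.Module):
    # Same fallback pattern as the hunk above: prefer the nn.LayerNorm attributes,
    # otherwise derive them from what a Cohere-style norm provides.
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    return normalized_shape, eps, elementwise_affine

print(extract_norm_attrs(nn.LayerNorm(64)))     # ((64,), 1e-05, True)
print(extract_norm_attrs(CohereStyleNorm(64)))  # (64, 1e-05, True)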
@@ -230,116 +225,7 @@ class FusedLayerNorm(BaseLayerNorm):
                 ApexFusedLayerNorm = FusedLayerNormWithHook
             except NameError:
                 warnings.warn(
-                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
-                )
-                return module
-
-        layernorm = (
-            ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
-        )
-        layernorm.weight = module.weight
-        layernorm.bias = module.bias
-
-        if sp_partial_derived:
-            # Since gradients are computed using only a subset of the data,
-            # aggregation of these gradients is necessary during backpropagation.
-            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
-
-        return layernorm
-
-
-class CohereLayerNorm(BaseLayerNorm):
-    r"""
-    This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "CohereLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module."
-        )
-
-    @staticmethod
-    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
-        r"""
-        Convert a CohereLayerNorm module to colossalai layer norm module,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: The LayerNorm module.
-
-        Raises:
-            AssertionError: If the provided module is not an instance of CohereLayerNorm
-        """
-
-        LazyInitContext.materialize(module)
-
-        if sp_partial_derived:
-            # Since gradients are computed using only a subset of the data,
-            # aggregation of these gradients is necessary during backpropagation.
-            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-
-        return module
-
-
-class FusedCohereLayerNorm(BaseLayerNorm):
-    r"""
-    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "FusedCohereLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex."
-        )
-
-    @staticmethod
-    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
-        r"""
-        Convert a CohereLayerNorm module to FusedLayerNorm module provided by apex,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
-
-        Raises:
-            AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm.
-        """
-
-        LazyInitContext.materialize(module)
-        # get the attributes of the module
-        normalized_shape = module.weight.size(0)
-        eps = module.variance_epsilon
-        elementwise_affine = True
-        dtype = module.weight.dtype
-        device = module.weight.device
-
-        # pick the suitable layernorm implementation
-        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
-
-        if use_fast_ln:
-            if EnableFastLayerNorm:
-                ApexFusedLayerNorm = FastLayerNormWithHook
-            else:
-                # fall back to the normal fused layernorm is not built
-                ApexFusedLayerNorm = FusedLayerNormWithHook
-        else:
-            try:
-                ApexFusedLayerNorm = FusedLayerNormWithHook
-            except NameError:
-                warnings.warn(
-                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
                 )
                 return module

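The removed block above also documents the kernel-selection chain that the surviving FusedLayerNorm keeps using: prefer the fast layernorm kernel when the hidden size is in the supported table, fall back to the generic apex fused kernel, and keep the native module (with a warning) when apex is missing. A condensed sketch of that decision, where the availability flags and the size table are hypothetical stand-ins for the real EnableFastLayerNorm / FAST_LAYERNORM_SUPPORTED_SIZE values:

import warnings

# Hypothetical stand-in; the real table comes from colossalai/apex at import time.
FAST_LN_SUPPORTED_SIZES = {1024, 2048, 4096, 8192}

def pick_layernorm_impl(hidden_size: int, fast_ln_built: bool, fused_ln_built: bool) -> str:
    # Condensed mirror of the branch structure visible in the hunk above.
    if hidden_size in FAST_LN_SUPPORTED_SIZES and fast_ln_built:
        return "fast"    # FastLayerNormWithHook
    if fused_ln_built:
        return "fused"   # FusedLayerNormWithHook
    warnings.warn("Apex not available, using native layernorm instead.")
    return "native"      # keep the unmodified module

print(pick_layernorm_impl(4096, fast_ln_built=True, fused_ln_built=True))   # fast
print(pick_layernorm_impl(1000, fast_ln_built=False, fused_ln_built=False)) # native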
@@ -347,6 +233,8 @@ class FusedCohereLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
         )
         layernorm.weight = module.weight
+        if module.bias is not None:
+            layernorm.bias = module.bias

         if sp_partial_derived:
             # Since gradients are computed using only a subset of the data,

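For reference, assigning layernorm.weight = module.weight re-registers the existing Parameter object on the newly built fused module rather than copying it, and the bias handoff now only happens when the source module actually has one. A small sketch of that assignment semantics using two plain nn.LayerNorm instances (not the apex class) purely for illustration:

import torch.nn as nn

src = nn.LayerNorm(64)   # stands in for the module being converted
dst = nn.LayerNorm(64)   # stands in for the freshly built fused layernorm

dst.weight = src.weight              # nn.Module re-registers the same Parameter, no copy
assert dst.weight is src.weight

if src.bias is not None:             # mirrors the new guard in the hunk above
    dst.bias = src.bias
    assert dst.bias is src.bias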
@@ -7,8 +7,8 @@ from torch import Tensor
 from torch.nn import Module

 from colossalai.shardformer.layer import (
-    CohereLayerNorm,
-    FusedCohereLayerNorm,
+    FusedLayerNorm,
+    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     PaddingEmbedding,

@@ -64,9 +64,9 @@ class CommandPolicy(Policy):
             embedding_cls = PaddingEmbedding

         if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedCohereLayerNorm
+            norm_cls = FusedLayerNorm
         else:
-            norm_cls = CohereLayerNorm
+            norm_cls = LayerNorm

         if self.pipeline_stage_manager is not None:
             self.shard_config.enable_sequence_parallelism = False

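A condensed sketch of the norm-class selection CommandPolicy performs after this change; the ShardConfig stub below is a hypothetical stand-in for the real shard_config object, and the two classes are placeholders for the generic wrappers imported above:

from dataclasses import dataclass

class LayerNorm: ...        # placeholder for colossalai.shardformer.layer.LayerNorm
class FusedLayerNorm: ...   # placeholder for colossalai.shardformer.layer.FusedLayerNorm

@dataclass
class ShardConfig:
    enable_fused_normalization: bool = False  # hypothetical stub of the relevant flag

def pick_norm_cls(shard_config: ShardConfig) -> type:
    # Same branch as the hunk above: with the Cohere-specific wrappers removed,
    # the policy hands out the generic wrappers for both settings.
    return FusedLayerNorm if shard_config.enable_fused_normalization else LayerNorm

assert pick_norm_cls(ShardConfig(enable_fused_normalization=True)) is FusedLayerNorm
assert pick_norm_cls(ShardConfig()) is LayerNorm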