diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 33e500034..f17fad1b6 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d
-from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 
@@ -23,8 +23,6 @@ __all__ = [
     "RMSNorm",
     "FusedLayerNorm",
     "FusedRMSNorm",
-    "CohereLayerNorm",
-    "FusedCohereLayerNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
     "PaddingEmbedding",
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 34a126904..59e1da9fc 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -4,7 +4,6 @@ import warnings
 from abc import ABC, abstractmethod
 
 import torch.nn as nn
-from transformers.models.cohere.modeling_cohere import CohereLayerNorm
 
 from colossalai.lazy import LazyInitContext
 
@@ -141,32 +140,29 @@ class RMSNorm(BaseLayerNorm):
 class LayerNorm(BaseLayerNorm):
     r"""
-    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "LayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module,
+        Convert a native LayerNorm module to colossalai layer norm module,
         and optionally marking parameters for gradient aggregation.
 
         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
 
         Returns:
-            nn.Module: The LayerNorm module.
+            nn.Module: The colossalai LayerNorm module.
 
-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
 
         LazyInitContext.materialize(module)
 
@@ -175,7 +171,8 @@ class LayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            if module.bias is not None:
+                SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
 
         return module
 
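Note: with the isinstance assertion dropped and the bias marking guarded, this wrapper now accepts any LayerNorm-like module, including bias-less ones such as CohereLayerNorm. A minimal usage sketch, not part of the patch; BiaslessLayerNorm is a hypothetical stand-in for such a module, not the transformers class:

import torch
import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm


class BiaslessLayerNorm(nn.Module):
    # Hypothetical stand-in for a CohereLayerNorm-style module: a weight but no bias,
    # and `variance_epsilon` instead of `eps`.
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = None
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        return self.weight * (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)


# Both a plain nn.LayerNorm and the bias-less module now go through the same wrapper;
# the bias is only marked for sequence-parallel gradient aggregation when it exists.
sharded_ln = LayerNorm.from_native_module(nn.LayerNorm(1024), sp_partial_derived=True)
sharded_cohere_like_ln = LayerNorm.from_native_module(BiaslessLayerNorm(1024), sp_partial_derived=True)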
@@ -188,140 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
     def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
         and optionally marking parameters for gradient aggregation.
 
         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
 
         Returns:
             nn.Module: Union[FastLayerNorm, FusedLayerNorm].
 
-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
 
         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
-        dtype = module.weight.dtype
-        device = module.weight.device
-
-        # pick the suitable layernorm implementation
-        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
-
-        if use_fast_ln:
-            if EnableFastLayerNorm:
-                ApexFusedLayerNorm = FastLayerNormWithHook
-            else:
-                # fall back to the normal fused layernorm is not built
-                ApexFusedLayerNorm = FusedLayerNormWithHook
-        else:
-            try:
-                ApexFusedLayerNorm = FusedLayerNormWithHook
-            except NameError:
-                warnings.warn(
-                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
-                )
-                return module
-
-        layernorm = (
-            ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
-        )
-        layernorm.weight = module.weight
-        layernorm.bias = module.bias
-
-        if sp_partial_derived:
-            # Since gradients are computed using only a subset of the data,
-            # aggregation of these gradients is necessary during backpropagation.
-            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
-
-        return layernorm
-
-
-class CohereLayerNorm(BaseLayerNorm):
-    r"""
-    This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "CohereLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module."
-        )
-
-    @staticmethod
-    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
-        r"""
-        Convert a CohereLayerNorm module to colossalai layer norm module,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: The LayerNorm module.
-
-        Raises:
-            AssertionError: If the provided module is not an instance of CohereLayerNorm
-        """
-
-        LazyInitContext.materialize(module)
-
-        if sp_partial_derived:
-            # Since gradients are computed using only a subset of the data,
-            # aggregation of these gradients is necessary during backpropagation.
-            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-
-        return module
-
-
-class FusedCohereLayerNorm(BaseLayerNorm):
-    r"""
-    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "FusedCohereLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex."
-        )
-
-    @staticmethod
-    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
-        r"""
-        Convert a CohereLayerNorm module to FusedLayerNorm module provided by apex,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
-
-        Raises:
-            AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm.
-        """
-
-        LazyInitContext.materialize(module)
-        # get the attributes of the module
-        normalized_shape = module.weight.size(0)
-        eps = module.variance_epsilon
-        elementwise_affine = True
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
         dtype = module.weight.dtype
         device = module.weight.device
 
@@ -339,7 +225,7 @@ class FusedCohereLayerNorm(BaseLayerNorm):
                 ApexFusedLayerNorm = FusedLayerNormWithHook
             except NameError:
                 warnings.warn(
-                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
                )
                 return module
 
@@ -347,6 +233,8 @@ class FusedCohereLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
         )
         layernorm.weight = module.weight
+        if module.bias is not None:
+            layernorm.bias = module.bias
 
         if sp_partial_derived:
             # Since gradients are computed using only a subset of the data,
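The generalized attribute lookup above lets the same fused wrapper consume both nn.LayerNorm (which carries normalized_shape, eps and elementwise_affine) and CohereLayerNorm-style modules (weight plus variance_epsilon only). A quick sketch, outside the patch, of how those three lines resolve for each flavor; the SimpleNamespace object is just an illustrative stand-in for the Cohere-style module:

from types import SimpleNamespace

import torch
import torch.nn as nn

# Minimal stand-in carrying only the attributes a CohereLayerNorm-like module exposes.
cohere_like = SimpleNamespace(weight=torch.ones(1024), variance_epsilon=1e-5)
native = nn.LayerNorm(1024)

for module in (native, cohere_like):
    # Same expressions as the three added lines in the hunk above.
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    print(normalized_shape, eps, elementwise_affine)
# nn.LayerNorm         -> (1024,) 1e-05 True
# CohereLayerNorm-like -> 1024    1e-05 True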
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 6c4785912..e2a367f74 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -7,8 +7,8 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import (
-    CohereLayerNorm,
-    FusedCohereLayerNorm,
+    FusedLayerNorm,
+    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     PaddingEmbedding,
@@ -64,9 +64,9 @@ class CommandPolicy(Policy):
             embedding_cls = PaddingEmbedding
 
         if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedCohereLayerNorm
+            norm_cls = FusedLayerNorm
         else:
-            norm_cls = CohereLayerNorm
+            norm_cls = LayerNorm
 
         if self.pipeline_stage_manager is not None:
             self.shard_config.enable_sequence_parallelism = False
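With the Cohere-specific wrappers removed, CommandPolicy reuses the shared normalization wrappers. Roughly, the selection reduces to the sketch below; pick_norm_cls is an illustrative helper, not part of the policy API:

from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm


def pick_norm_cls(enable_fused_normalization: bool):
    # Mirrors the branch in the hunk above: the Apex-backed wrapper when fused
    # normalization is enabled, otherwise the plain colossalai LayerNorm wrapper.
    return FusedLayerNorm if enable_fused_normalization else LayerNorm


norm_cls = pick_norm_cls(enable_fused_normalization=True)
# Whichever class is chosen, it is applied later through the same interface:
#     norm_cls.from_native_module(native_norm, sp_partial_derived=...)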