@@ -4,7 +4,6 @@ import warnings
 from abc import ABC, abstractmethod
 
 import torch.nn as nn
-from transformers.models.cohere.modeling_cohere import CohereLayerNorm
 
 from colossalai.lazy import LazyInitContext
@@ -141,32 +140,29 @@ class RMSNorm(BaseLayerNorm):
 class LayerNorm(BaseLayerNorm):
     r"""
-    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    This is a wrapper around native LayerNorm. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "LayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+            "It is meant to be used only with the from_native_module interface to convert a native LayerNorm module to colossalai layer norm module."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module,
+        Convert a native LayerNorm module to colossalai layer norm module,
         and optionally marking parameters for gradient aggregation.
 
         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
 
         Returns:
-            nn.Module: The LayerNorm module.
+            nn.Module: The colossalai LayerNorm module.
-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
         LazyInitContext.materialize(module)
@@ -175,6 +171,7 @@ class LayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            if module.bias is not None:
+                SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
 
         return module
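
With this hunk, LayerNorm.from_native_module accepts any nn.Module and only touches the bias when one exists. A minimal usage sketch (not part of the diff), assuming the wrapper is re-exported as colossalai.shardformer.layer.LayerNorm and using an illustrative hidden size:

    import torch.nn as nn

    from colossalai.shardformer.layer import LayerNorm  # assumed export path

    ln = nn.LayerNorm(1024)  # illustrative hidden size
    # The call materializes the module (for lazy init) and returns it unchanged;
    # with sp_partial_derived=True the weight, and the bias only if present,
    # are marked for gradient aggregation under sequence parallelism.
    ln = LayerNorm.from_native_module(ln, sp_partial_derived=True)
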
@@ -188,31 +185,29 @@ class FusedLayerNorm(BaseLayerNorm):
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
+            "It is meant to be used only with the from_native_module interface convert a native LayerNorm module to FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
     def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        Convert a native LayerNorm module to FusedLayerNorm module provided by apex,
         and optionally marking parameters for gradient aggregation.
 
         Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            module (nn.Module): The native LayerNorm module to be converted.
             sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
 
         Returns:
             nn.Module: Union[FastLayerNorm, FusedLayerNorm].
-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
         dtype = module.weight.dtype
         device = module.weight.device
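
The three fallbacks above are the heart of the generalization: LayerNorm variants that lack nn.LayerNorm's attributes (for example CohereLayerNorm, which stores variance_epsilon and only a weight) still yield usable constructor arguments. A small standalone sketch of the same lookup pattern (helper name and sizes are illustrative, not from the diff):

    import torch.nn as nn

    def extract_ln_args(module: nn.Module):
        # Same fallback pattern as the hunk above: prefer nn.LayerNorm's own
        # attributes, otherwise derive them from the weight and variance_epsilon
        # of a LayerNorm-like module.
        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
        elementwise_affine = getattr(module, "elementwise_affine", True)
        return normalized_shape, eps, elementwise_affine

    print(extract_ln_args(nn.LayerNorm(1024, eps=1e-6)))  # ((1024,), 1e-06, True)
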
@@ -230,7 +225,7 @@ class FusedLayerNorm(BaseLayerNorm):
                 ApexFusedLayerNorm = FusedLayerNormWithHook
             except NameError:
                 warnings.warn(
-                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using native layernorm instead."
                 )
                 return module
@@ -238,6 +233,7 @@ class FusedLayerNorm(BaseLayerNorm):
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
         )
         layernorm.weight = module.weight
-        layernorm.bias = module.bias
+        if module.bias is not None:
+            layernorm.bias = module.bias
 
         if sp_partial_derived:
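
Together with the attribute fallbacks, the bias guard lets FusedLayerNorm.from_native_module absorb LayerNorm variants that carry no bias, which is what makes the dedicated Cohere wrappers removed in the next hunk redundant. A hedged sketch with a stand-in module (the class, export path, and size are illustrative, not from the diff); it assumes Apex is installed, otherwise the call warns and returns the module unchanged:

    import torch
    import torch.nn as nn

    from colossalai.shardformer.layer import FusedLayerNorm  # assumed export path

    class BiasLessLayerNorm(nn.Module):
        """Stand-in for a CohereLayerNorm-style norm: weight only, bias set to None,
        eps stored as variance_epsilon."""

        def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = None
            self.variance_epsilon = eps

    fused = FusedLayerNorm.from_native_module(BiasLessLayerNorm(4096))
    # normalized_shape falls back to weight.shape[0], eps to variance_epsilon,
    # elementwise_affine to True; the weight is re-used and the bias copy is skipped.
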
@@ -250,114 +246,6 @@ class FusedLayerNorm(BaseLayerNorm):
         return layernorm
-class CohereLayerNorm(BaseLayerNorm):
-    r"""
-    This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "CohereLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to convert a transformers.models.cohere.CohereLayerNorm module to colossalai layer norm module."
-        )
-
-    @staticmethod
-    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
-        r"""
-        Convert a CohereLayerNorm module to colossalai layer norm module,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: The LayerNorm module.
-
-        Raises:
-            AssertionError: If the provided module is not an instance of CohereLayerNorm.
-        """
-        LazyInitContext.materialize(module)
-
-        if sp_partial_derived:
-            # Since gradients are computed using only a subset of the data,
-            # aggregation of these gradients is necessary during backpropagation.
-            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-
-        return module
-
-
-class FusedCohereLayerNorm(BaseLayerNorm):
-    r"""
-    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
-    """
-
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "FusedCohereLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a transformers.models.cohere.CohereLayerNorm module to FusedLayerNorm module provided by apex."
-        )
-
-    @staticmethod
-    def from_native_module(module: CohereLayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
-        r"""
-        Convert a CohereLayerNorm module to FusedLayerNorm module provided by apex,
-        and optionally marking parameters for gradient aggregation.
-
-        Args:
-            module (transformers.models.cohere.CohereLayerNorm): The CohereLayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
-
-        Returns:
-            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
-
-        Raises:
-            AssertionError: If the provided module is not an instance of transformers.models.cohere.CohereLayerNorm.
-        """
-        LazyInitContext.materialize(module)
-        # get the attributes of the module
-        normalized_shape = module.weight.size(0)
-        eps = module.variance_epsilon
-        elementwise_affine = True
-        dtype = module.weight.dtype
-        device = module.weight.device
-
-        # pick the suitable layernorm implementation
-        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
-
-        if use_fast_ln:
-            if EnableFastLayerNorm:
-                ApexFusedLayerNorm = FastLayerNormWithHook
-            else:
-                # fall back to the normal fused layernorm is not built
-                ApexFusedLayerNorm = FusedLayerNormWithHook
-        else:
-            try:
-                ApexFusedLayerNorm = FusedLayerNormWithHook
-            except NameError:
-                warnings.warn(
-                    "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
-                )
-                return module
-
-        layernorm = (
-            ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
-        )
-        layernorm.weight = module.weight
-
-        if sp_partial_derived:
-            # Since gradients are computed using only a subset of the data,
-            # aggregation of these gradients is necessary during backpropagation.
-            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
-
-        return layernorm
 
 
 class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.