diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 5aa21260..01453a05 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm):
         Raises:
             AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+        # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
 
         LazyInitContext.materialize(module)
 
@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
 
         return module
 
@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm):
 
         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        # normalized_shape = module.normalized_shape
+        # eps = module.eps
+        # elementwise_affine = module.elementwise_affine
+        normalized_shape = module.weight.size(0)
+        eps = module.variance_epsilon
+        elementwise_affine = True
         dtype = module.weight.dtype
         device = module.weight.device
 
@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
 
         return layernorm
 
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 6075f836..a7166e38 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     ],
 )
 def run_command_test(test_config):
+    print(test_config)
     sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
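
Reviewer note (not part of the patch): the change to FusedLayerNorm.from_native_module appears to target RMSNorm-style norm modules such as transformers' CohereLayerNorm used by Command models, which expose only a weight tensor and variance_epsilon and have no bias or elementwise_affine attribute. A less invasive variant could branch on the module type instead of commenting the original lines out; the sketch below is a hypothetical illustration of that idea (the _get_norm_attrs helper is not part of ColossalAI), written under the assumption that any non-nn.LayerNorm input has exactly that weight/variance_epsilon shape.

import torch.nn as nn

def _get_norm_attrs(module: nn.Module):
    # Hypothetical helper: read LayerNorm-like attributes from either a plain
    # nn.LayerNorm or an RMSNorm-style module that only carries `weight` and
    # `variance_epsilon` (assumed shape of the Command/Cohere layer norm).
    if isinstance(module, nn.LayerNorm):
        return module.normalized_shape, module.eps, module.elementwise_affine
    return (module.weight.size(0),), module.variance_epsilon, True

With such a helper, the three commented-out assignments collapse into a single call, and the bias-related marked_as_sp_partial_derived_param calls could be guarded with a check such as `if getattr(module, "bias", None) is not None:` so that plain nn.LayerNorm inputs keep their bias gradient aggregation rather than losing it for all callers. The print(test_config) added to run_command_test also looks like debugging output that would normally be dropped before merge.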