mirror of https://github.com/hpcaitech/ColossalAI
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 5aa21260..01453a05 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm):
         Raises:
             AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+        # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."

         LazyInitContext.materialize(module)

@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)

         return module

@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm):

         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        # normalized_shape = module.normalized_shape
+        # eps = module.eps
+        # elementwise_affine = module.elementwise_affine
+        normalized_shape = module.weight.size(0)
+        eps = module.variance_epsilon
+        elementwise_affine = True
         dtype = module.weight.dtype
         device = module.weight.device

@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)

         return layernorm

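Note on the hunks above: they relax the nn.LayerNorm-only conversion so that a bias-free layer norm can pass through from_native_module. The isinstance assertion and the bias annotation are commented out, normalized_shape is now taken from module.weight.size(0), eps from module.variance_epsilon, and elementwise_affine is forced to True. The sketch below is illustrative only and not part of the patch; BiasFreeLayerNorm is a hypothetical stand-in for a Command/Cohere-style norm module that exposes weight and variance_epsilon but no bias.

# Illustrative sketch only, NOT part of the patch above. It shows the attribute
# contract the modified from_native_module now assumes: a norm module with a
# `weight` parameter and a `variance_epsilon` attribute but no `bias`.
# `BiasFreeLayerNorm` is a hypothetical stand-in, not ColossalAI or HF code.
import torch
import torch.nn as nn


class BiasFreeLayerNorm(nn.Module):
    """Hypothetical bias-free LayerNorm (weight + variance_epsilon only)."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps  # attribute name read by the patched code

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Standard LayerNorm math, just without a learned bias term.
        mean = hidden_states.mean(-1, keepdim=True)
        var = hidden_states.var(-1, keepdim=True, unbiased=False)
        return self.weight * (hidden_states - mean) / torch.sqrt(var + self.variance_epsilon)


module = BiasFreeLayerNorm(hidden_size=64)
# This mirrors how the patched FusedLayerNorm.from_native_module derives its arguments:
normalized_shape = module.weight.size(0)  # 64, instead of module.normalized_shape
eps = module.variance_epsilon             # instead of module.eps
elementwise_affine = True                 # forced, since only `weight` exists
print(normalized_shape, eps, elementwise_affine)
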
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 6075f836..a7166e38 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     ],
 )
 def run_command_test(test_config):
+    print(test_config)
     sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
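The test-side change only prints each parallel configuration as it runs. To see that output when running the modified test locally, a hypothetical invocation via pytest's Python API could look like this (assuming ColossalAI's test dependencies and the GPUs the test expects are available):

import pytest

# "-s" disables output capture so the added print(test_config) is visible.
pytest.main(["-s", "tests/test_shardformer/test_model/test_shard_command.py"])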