ColossalAI/diff.output

diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 5aa21260..01453a05 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm):
         Raises:
             AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+        # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
         LazyInitContext.materialize(module)
@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
         return module
@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm):
         LazyInitContext.materialize(module)
         # get the attributes of the module
-        normalized_shape = module.normalized_shape
-        eps = module.eps
-        elementwise_affine = module.elementwise_affine
+        # normalized_shape = module.normalized_shape
+        # eps = module.eps
+        # elementwise_affine = module.elementwise_affine
+        normalized_shape = module.weight.size(0)
+        eps = module.variance_epsilon
+        elementwise_affine = True
         dtype = module.weight.dtype
         device = module.weight.device
@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm):
             # aggregation of these gradients is necessary during backpropagation.
             # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
             SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
-            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+            # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
         return layernorm
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 6075f836..a7166e38 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     ],
 )
 def run_command_test(test_config):
+    print(test_config)
     sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
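
For context: the FusedLayerNorm hunk above replaces the nn.LayerNorm attribute reads (normalized_shape, eps, elementwise_affine) with module.weight.size(0), module.variance_epsilon, and a hard-coded True, and the bias-related lines are commented out. That only works if the module being converted is a bias-free layer norm exposing `weight` and `variance_epsilon`, such as the Command/Cohere layer norm exercised by test_shard_command.py. A minimal sketch of such a module, with illustrative (assumed) class name and defaults:

import torch
import torch.nn as nn


class CommandLikeLayerNorm(nn.Module):
    # Bias-free layer norm in the style of transformers' Command/Cohere models:
    # it exposes `weight` and `variance_epsilon`, but none of nn.LayerNorm's
    # `normalized_shape` / `eps` / `elementwise_affine` / `bias` attributes,
    # which is why the unpatched FusedLayerNorm attribute lookups fail on it.
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Standard layer norm, just without a bias term.
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


# With such a module, the patched lookups resolve as:
#   normalized_shape = module.weight.size(0)   -> hidden_size
#   eps = module.variance_epsilon
#   elementwise_affine = True                  (the scale is always learned)
# There is no bias parameter, so there is nothing to mark for gradient aggregation,
# hence the commented-out SeqParallelUtils calls on `.bias`.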