From 7e0ec5a85c73fcc5666b9d218e43865141587dde Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 2 Apr 2024 20:11:18 +0800 Subject: [PATCH] fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz --- colossalai/shardformer/shard/shard_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 646b61193..ce78a7e94 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -74,8 +74,10 @@ class ShardConfig: self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True - self.enable_sequence_parallelism = True - self.enable_sequence_overlap = True + # This can cause non-in-place param sharding when used without ZeRO. + # It may also slow down training when seq len is small. Plz enable manually. + # self.enable_sequence_parallelism = True + # self.enable_sequence_overlap = True def _infer(self): """