fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
pull/5548/head
Edenzzzz 2024-04-02 20:11:18 +08:00 committed by GitHub
parent e614aa34f3
commit 7e0ec5a85c
1 changed file with 4 additions and 2 deletions

@@ -74,8 +74,10 @@ class ShardConfig:
         self.enable_fused_normalization = True
         self.enable_flash_attention = True
         self.enable_jit_fused = True
-        self.enable_sequence_parallelism = True
-        self.enable_sequence_overlap = True
+        # This can cause non-in-place param sharding when used without ZeRO.
+        # It may also slow down training when seq len is small. Plz enable manually.
+        # self.enable_sequence_parallelism = True
+        # self.enable_sequence_overlap = True
 
     def _infer(self):
         """