fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Branch: pull/5548/head
Author: Edenzzzz (committed via GitHub, 8 months ago)
Parent: e614aa34f3
Commit: 7e0ec5a85c

@@ -74,8 +74,10 @@ class ShardConfig:
         self.enable_fused_normalization = True
         self.enable_flash_attention = True
         self.enable_jit_fused = True
-        self.enable_sequence_parallelism = True
-        self.enable_sequence_overlap = True
+        # This can cause non-in-place param sharding when used without ZeRO.
+        # It may also slow down training when seq len is small. Plz enable manually.
+        # self.enable_sequence_parallelism = True
+        # self.enable_sequence_overlap = True
 
     def _infer(self):
         """
