diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 4b78624f0..a7b552c9e 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -250,7 +250,7 @@ def main(): plugin = None if args.distplan.startswith("CAI_ZeRO"): plugin = LowLevelZeroPlugin(stage=zero_stage, - reduce_bucket_size_in_m=12 * 1024 * 1024, + reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True) elif args.distplan == "CAI_Gemini":