[Gemini] add reduce-scatter overlap and chunk prefetch options to the llama benchmark. (#5751)

* [bugs] fix args.profile=False DummyProfiler error

* add args.prefetch_num for benchmark
Haze188 committed 4d097def96 (parent ca674549e0)

@@ -79,7 +79,7 @@ def main():
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
     parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
     parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     args = parser.parse_args()
     colossalai.launch_from_torch()
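The new flag defaults to 0, so chunk prefetching becomes opt-in instead of the previously hard-coded max_prefetch=10 (see the next hunk). A minimal, self-contained sketch of the parsing behaviour; the snippet is illustrative only and not part of the benchmark script:

# Illustrative snippet (not from the repo): default and explicit values of the new flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")

print(parser.parse_args([]).prefetch_num)                       # 0 -> chunk prefetch disabled
print(parser.parse_args(["--prefetch_num", "2"]).prefetch_num)  # 2 -> prefetch up to 2 chunks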
@@ -114,7 +114,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
-            max_prefetch=10,
+            max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
         )
     elif args.plugin == "gemini_auto":
@@ -125,6 +125,8 @@ def main():
             tp_size=args.tp,
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
+            max_prefetch=args.prefetch_num,
+            enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
         )
     elif args.plugin == "fsdp":
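Both hunks route the two new switches into GeminiPlugin. A sketch of the resulting plugin construction, assuming the usual ColossalAI import path; the kwarg names max_prefetch and enable_async_reduce come from the hunks above, while the literal values stand in for args.prefetch_num and args.disable_async_reduce and are assumptions for illustration:

# Sketch only: how the new CLI switches reach the plugin.
import torch
from colossalai.booster.plugin import GeminiPlugin

prefetch_num = 2              # stand-in for args.prefetch_num
disable_async_reduce = False  # stand-in for args.disable_async_reduce

plugin = GeminiPlugin(
    enable_fused_normalization=torch.cuda.is_available(),
    max_prefetch=prefetch_num,                       # max chunks fetched ahead of use; 0 turns prefetch off
    enable_async_reduce=not disable_async_reduce,    # overlap reduce-scatter with compute
)

With the defaults (--prefetch_num 0 and --disable-async-reduce not given), prefetching stays off while the reduce-scatter overlap stays on.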

@@ -36,6 +36,12 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
         def step(self):
             self.step_number += 1
+        def __enter__(self):
+            return self
+        def __exit__(self, exc_type, exc_value, traceback):
+            pass
     if enable_flag:
         return profile(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
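This hunk is the args.profile=False fix: the object returned by get_profile_context is apparently entered via a with statement, and without __enter__/__exit__ the dummy profiler is not a context manager, so the with line raises. A standalone sketch of the fixed class; __init__ is assumed, while step and the two dunder methods mirror the hunk above:

# Standalone sketch of the fixed no-op profiler (__init__ is assumed).
class DummyProfiler:
    def __init__(self):
        self.step_number = 0

    def step(self):
        self.step_number += 1

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


# Usage mirroring the benchmark's profiling loop shape (assumed):
with DummyProfiler() as prof:
    for _ in range(3):
        prof.step()  # no-op bookkeeping when profiling is disabled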
