From fba04e857b57abc54ba4864cbfb3af0461e2c5e7 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Sat, 25 May 2024 14:55:09 +0000 Subject: [PATCH 1/3] [bugs] fix args.profile=False DummyProfiler errro --- examples/language/performance_evaluator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 0b147b7ea..99df8f1da 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -36,6 +36,12 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): def step(self): self.step_number += 1 + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + if enable_flag: return profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], From b9269d962dff742df667ae19000f63622b45f56b Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Sat, 25 May 2024 14:55:50 +0000 Subject: [PATCH 2/3] add args.prefetch_num for benchmark --- examples/language/llama/benchmark.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 712703b45..b71203518 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -79,7 +79,7 @@ def main(): parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) - + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") args = parser.parse_args() colossalai.launch_from_torch() @@ -114,7 +114,7 @@ def main(): extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, - max_prefetch=10, + max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, ) elif args.plugin == "gemini_auto": @@ -125,6 +125,8 @@ def main(): tp_size=args.tp, extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), + max_prefetch=args.prefetch_num, + enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp": From 87665d79228df9e8e40363e731874939f3b66b2f Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 27 May 2024 06:03:53 +0000 Subject: [PATCH 3/3] correct argument help message --- examples/language/llama/benchmark.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index b71203518..8d4dae314 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -78,7 +78,9 @@ def main(): parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) - parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) + parser.add_argument( + "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False + ) parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") args = parser.parse_args()