
add args.prefetch_num for benchmark

Branch: pull/5751/head
Author: genghaozhe, 6 months ago
Commit: b9269d962d
1 changed file, 6 changed lines:

examples/language/llama/benchmark.py
@@ -79,7 +79,7 @@ def main():
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
     parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
     parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     args = parser.parse_args()
     colossalai.launch_from_torch()
@@ -114,7 +114,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
-            max_prefetch=10,
+            max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
         )
     elif args.plugin == "gemini_auto":
@@ -125,6 +125,8 @@ def main():
             tp_size=args.tp,
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
+            max_prefetch=args.prefetch_num,
+            enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
         )
     elif args.plugin == "fsdp":
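
The change replaces the benchmark's hard-coded max_prefetch=10 with a value taken from the command line, so chunk prefetching can be tuned per run (0, the default, turns it off). Below is a minimal sketch, not part of the commit, of how the new flag feeds the plugin; it assumes the GeminiPlugin keyword arguments visible in the diff (max_prefetch, enable_async_reduce) and a torchrun-style launch:

    # Sketch only: mirrors the wiring added by this commit, under the
    # assumption that this ColossalAI version's GeminiPlugin accepts
    # max_prefetch and enable_async_reduce as shown in the diff above.
    import argparse

    import colossalai
    from colossalai.booster.plugin import GeminiPlugin

    parser = argparse.ArgumentParser()
    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
    parser.add_argument("--disable-async-reduce", action="store_true", default=False)
    args = parser.parse_args()

    colossalai.launch_from_torch()  # expects torchrun-provided env vars, as in benchmark.py
    plugin = GeminiPlugin(
        max_prefetch=args.prefetch_num,                     # 0 disables chunk prefetching
        enable_async_reduce=not args.disable_async_reduce,  # async gradient reduce stays on by default
    )

With the patched benchmark itself, the flag would be passed on the command line, e.g. torchrun --nproc_per_node 8 benchmark.py --plugin gemini --prefetch_num 4; the --plugin flag name and the values here are assumptions based on the branches visible in the diff, not taken from the commit.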
