|
|
|
@ -79,7 +79,7 @@ def main():
|
|
|
|
|
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) |
|
|
|
|
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) |
|
|
|
|
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable asynchronous gradient reduce", default=False) |
|
|
|
|
|
|
|
|
|
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
colossalai.launch_from_torch() |
|
|
|
@ -114,7 +114,7 @@ def main():
|
|
|
|
|
extra_dp_size=args.extra_dp, |
|
|
|
|
enable_fused_normalization=torch.cuda.is_available(), |
|
|
|
|
enable_flash_attention=args.xformers, |
|
|
|
|
max_prefetch=10, |
|
|
|
|
max_prefetch=args.prefetch_num, |
|
|
|
|
enable_async_reduce=not args.disable_async_reduce, |
|
|
|
|
) |
|
|
|
|
elif args.plugin == "gemini_auto": |
|
|
|
@ -125,6 +125,8 @@ def main():
|
|
|
|
|
tp_size=args.tp, |
|
|
|
|
extra_dp_size=args.extra_dp, |
|
|
|
|
enable_fused_normalization=torch.cuda.is_available(), |
|
|
|
|
max_prefetch=args.prefetch_num, |
|
|
|
|
enable_async_reduce=not args.disable_async_reduce, |
|
|
|
|
enable_flash_attention=args.xformers, |
|
|
|
|
) |
|
|
|
|
elif args.plugin == "fsdp": |
|
|
|
|