|
|
|
@ -76,9 +76,11 @@ def main(args):
|
|
|
|
|
if args.strategy == "ddp": |
|
|
|
|
strategy = DDPStrategy() |
|
|
|
|
elif args.strategy == "colossalai_gemini": |
|
|
|
|
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5) |
|
|
|
|
strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5) |
|
|
|
|
elif args.strategy == "colossalai_gemini_cpu": |
|
|
|
|
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) |
|
|
|
|
strategy = GeminiStrategy( |
|
|
|
|
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5 |
|
|
|
|
) |
|
|
|
|
elif args.strategy == "colossalai_zero2": |
|
|
|
|
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") |
|
|
|
|
elif args.strategy == "colossalai_zero2_cpu": |
|
|
|
|