diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index c0ac7b177..22f70e485 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -35,6 +35,8 @@ def train(args): strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') elif args.strategy == 'colossalai_zero2': strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -168,7 +170,7 @@ def train(args): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], default='naive') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument('--pretrain', type=str, default=None)