mirror of https://github.com/hpcaitech/ColossalAI
[NFC] policy applications/Chat/examples/ray/mmmt_prompt.py code style (#4250)
parent
77c469e1ba
commit
dee1c96344
|
@ -87,8 +87,8 @@ def main(args):
|
|||
kl_coef=0.1,
|
||||
debug=args.debug,
|
||||
update_lora_weights=not (args.lora_rank == 0),
|
||||
# sync_models_from_trainers=True,
|
||||
# generation kwargs:
|
||||
# sync_models_from_trainers=True,
|
||||
# generation kwargs:
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
|
@ -161,12 +161,10 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--prompt_path', type=str, default=None)
|
||||
parser.add_argument('--num_makers', type=int, default=1)
|
||||
parser.add_argument('--num_trainers', type=int, default=1)
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=[
|
||||
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
|
||||
'colossalai_zero2_cpu'
|
||||
],
|
||||
default='ddp')
|
||||
parser.add_argument(
|
||||
'--trainer_strategy',
|
||||
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'],
|
||||
default='ddp')
|
||||
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
|
|
Loading…
Reference in New Issue