mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/examples/train_reward_model.py code style (#4271)
parent
a50d39a143
commit
1ce997daaf
|
@ -150,9 +150,7 @@ def train(args):
|
|||
pin_memory=True)
|
||||
|
||||
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
|
||||
strategy_dict = strategy.prepare(
|
||||
dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
|
||||
)
|
||||
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
|
||||
model = strategy_dict['model']
|
||||
optim = strategy_dict['optimizer']
|
||||
lr_scheduler = strategy_dict['lr_scheduler']
|
||||
|
@ -163,9 +161,7 @@ def train(args):
|
|||
loss_fn=loss_fn,
|
||||
max_epochs=args.max_epochs)
|
||||
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
valid_dataloader=valid_dataloader,
|
||||
eval_dataloader=eval_dataloader)
|
||||
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(model, args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
|
|
Loading…
Reference in New Issue