[NFC] polish applications/Chat/examples/train_reward_model.py code style (#4271)

pull/4338/head
Xu Kai 2023-07-18 18:01:52 +08:00 committed by binmakeswell
parent a50d39a143
commit 1ce997daaf
1 changed file with 2 additions and 6 deletions

View File

@@ -150,9 +150,7 @@ def train(args):
pin_memory=True)
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
-strategy_dict = strategy.prepare(
-dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-)
+strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
@@ -163,9 +161,7 @@ def train(args):
loss_fn=loss_fn,
max_epochs=args.max_epochs)
-trainer.fit(train_dataloader=train_dataloader,
-valid_dataloader=valid_dataloader,
-eval_dataloader=eval_dataloader)
+trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks