Browse Source

Merge pull request #6029 from hpcaitech/flybird11111-patch-1

Update train_dpo.py
pull/6012/head
Wang Binluo 3 months ago committed by GitHub
parent
commit
0bf46c54af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      applications/ColossalChat/examples/training_scripts/train_dpo.py

6
applications/ColossalChat/examples/training_scripts/train_dpo.py

@@ -279,10 +279,7 @@ def train(args):
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
<<<<<<< HEAD
=======
apply_loss_mask=not args.disable_loss_mask,
>>>>>>> main
)
trainer.fit(
@@ -351,10 +348,7 @@ if __name__ == "__main__":
default=False,
help="Disable the reference model (enabled by default)",
)
<<<<<<< HEAD
=======
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
>>>>>>> main
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")

Loading…
Cancel
Save