diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 063ea233e..faa7a90d9 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -64,7 +64,9 @@ class DPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index dd7dabfe6..f0b23afb6 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -67,7 +67,9 @@ class KTOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.ref_model = ref_model self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index 9a3adcd73..761fd305a 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -60,7 +60,9 @@ class ORPOTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = actor_lr_scheduler self.tokenizer = tokenizer self.odds_ratio_loss_fn = OddsRatioLoss()