update rm

pull/6007/head
Tong Li 2024-08-12 11:27:42 +00:00
parent 38c84a1aa0
commit 5a24b0dc31
2 changed files with 4 additions and 2 deletions

View File

@@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
model: Any,
booster: Booster,
optimizer: Optimizer,
plugin: Plugin,
lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None,
@@ -59,7 +60,7 @@ class RewardModelTrainer(SLTrainer):
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch)
self.actor_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)

View File

@@ -262,6 +262,7 @@ def train(args):
model,
booster,
optim,
plugin,
lr_scheduler,
tokenizer,
loss_fn=loss_fn,