2023-04-27 10:41:49 +00:00
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
from .actor import Actor
|
|
|
|
from .critic import Critic
|
|
|
|
from .reward_model import RewardModel
|
|
|
|
|
2023-04-27 10:41:49 +00:00
|
|
|
|
|
|
|
def get_base_model(model: nn.Module) -> nn.Module:
    """Return the base model behind one of our wrapper classes.

    An ``Actor`` wraps its base model (``actor.model``, usually a
    ``transformers.PreTrainedModel``), so for an ``Actor`` the result of
    ``model.get_base_model()`` is returned. Every other module, such as
    ``Critic`` or ``RewardModel``, is its own base model and is returned
    unchanged.

    Args:
        model (nn.Module): model to get the base model from

    Returns:
        nn.Module: the base model
    """
    # Anything that is not an Actor is already the base model.
    if not isinstance(model, Actor):
        return model
    return model.get_base_model()
|
|
|
|
|
|
|
|
|
|
|
|
# Public API of this package: the three model wrappers plus the unwrap helper.
__all__ = [
    'Actor',
    'Critic',
    'RewardModel',
    'get_base_model',
]
|