2023-06-13 05:31:56 +00:00
|
|
|
from typing import Union
|
|
|
|
|
2023-04-27 10:41:49 +00:00
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
from .actor import Actor
|
|
|
|
from .critic import Critic
|
|
|
|
from .reward_model import RewardModel
|
|
|
|
|
2023-04-27 10:41:49 +00:00
|
|
|
|
2023-06-13 05:31:56 +00:00
|
|
|
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
    """Get the base model of our wrapper classes.

    For Actor, Critic and RewardModel, return ``model.model``,
    which is usually a ``transformers.PreTrainedModel``.

    Args:
        model (Actor | Critic | RewardModel): wrapper to get the base model
            from. Anything else (e.g. an object still wrapped by a training
            strategy) is rejected; the error message tells callers to use
            ``unwrap_model`` first.

    Returns:
        nn.Module: the base model, i.e. ``model.model``.

    Raises:
        AssertionError: if ``model`` is not an Actor, Critic or RewardModel.
    """
    # Explicit check instead of a bare ``assert`` so the validation is not
    # stripped when Python runs with ``-O``. AssertionError is kept so callers
    # that already handle it keep working.
    if not isinstance(model, (Actor, Critic, RewardModel)):
        raise AssertionError(
            f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
        )
    return model.model
|
2023-04-27 10:41:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: the three wrapper classes re-exported from their
# submodules, plus the helper that unwraps them to the underlying model.
__all__ = [
    "Actor",
    "Critic",
    "RewardModel",
    "get_base_model",
]
|