from typing import Union

import torch.nn as nn

from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel


def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
    """Return the underlying model held by one of this package's wrappers.

    ``Actor``, ``Critic`` and ``RewardModel`` each expose the wrapped network
    through their ``model`` attribute (usually a
    ``transformers.PreTrainedModel``); this helper simply hands that
    attribute back.

    Args:
        model (Union[Actor, Critic, RewardModel]): wrapper instance to unpack.

    Returns:
        nn.Module: the object stored in ``model.model``.

    Raises:
        AssertionError: if ``model`` is not an ``Actor``, ``Critic`` or
            ``RewardModel``. The check is an ``assert`` statement, so it is
            skipped when Python runs with assertions disabled (``-O``).
    """
    wrapper_types = (Actor, Critic, RewardModel)
    # Anything else (e.g. a model still wrapped by a training strategy) must
    # be unwrapped by the caller before reaching this point.
    assert isinstance(model, wrapper_types), (
        f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
    )
    base_model = model.model
    return base_model


__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']