# ColossalAI/applications/Chat/coati/models/base/__init__.py

from typing import Union
import torch.nn as nn
# (source-view blame timestamp: 2023-03-28 12:25:36 +00:00)
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
    """Get the base model of our wrapper classes.

    For ``Actor``, ``Critic`` and ``RewardModel``, return ``model.model``;
    it's usually a ``transformers.PreTrainedModel``.

    Args:
        model (Union[Actor, Critic, RewardModel]): wrapper model to get the
            base model from. Anything that is not one of these classes must
            be passed through ``unwrap_model`` first.

    Returns:
        nn.Module: the wrapped base model, i.e. ``model.model``.

    Raises:
        TypeError: if ``model`` is not an ``Actor``, ``Critic`` or
            ``RewardModel`` instance.
    """
    # Explicit check instead of ``assert``: assert statements are removed when
    # Python runs with ``-O``, which would silently skip this validation.
    if not isinstance(model, (Actor, Critic, RewardModel)):
        raise TypeError(
            f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.')
    return model.model
# Public names exported by this package on ``from ... import *``.
__all__ = [
    'Actor',
    'Critic',
    'RewardModel',
    'get_base_model',
]