2023-05-22 07:02:17 +00:00
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
def build_policies():
    r"""
    Assemble the registry that maps supported model classes to their policies.

    Both the ``transformers`` model classes and the local policy classes are
    imported inside this function rather than at module level, so importing
    this module alone does not pull in ``transformers``. A new dict is built on
    every call.

    Return:
        The dict for the policies: keys are ``transformers`` model classes,
        values are the matching entries imported from ``.bert`` / ``.gpt2``.
    """
    from transformers import (
        BertForMaskedLM,
        BertForSequenceClassification,
        GPT2LMHeadModel,
        GPT2Model,
    )

    from .bert import BertForMaskedLMPolicy, BertForSequenceClassificationPolicy
    from .gpt2 import GPT2LMHeadModelPolicy, GPT2Policy

    # Entry order is kept stable because anything enumerating the registry
    # (e.g. to list supported models) sees keys in insertion order.
    auto_policy_dict = {
        BertForMaskedLM: BertForMaskedLMPolicy,
        BertForSequenceClassification: BertForSequenceClassificationPolicy,
        GPT2Model: GPT2Policy,
        GPT2LMHeadModel: GPT2LMHeadModelPolicy,
    }
    return auto_policy_dict
|
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
|
|
|
def get_autopolicy(model: nn.Module):
    r"""
    Look up the auto policy registered for the class of ``model``.

    The registry from :func:`build_policies` is rebuilt on each call. Matching
    uses the exact class of ``model`` (``model.__class__``), so a subclass of a
    supported model is not matched.

    Args:
        model (:class:`nn.Module`): The model to get the auto policy for

    Return:
        The registry value for ``model.__class__``, returned as-is. It is
        presumably a policy class, and it is not instantiated here.

    Raises:
        NotImplementedError: If no policy is registered for the model's class;
            the message lists the qualified names of the supported classes.
    """
    registry = build_policies()
    model_cls = model.__class__

    policy = registry.get(model_cls, None)
    if policy is not None:
        return policy

    supported = [cls.__qualname__ for cls in registry.keys()]
    raise NotImplementedError(
        f"Auto policy for {model_cls.__qualname__} is not implemented\n Supported models are {supported}"
    )
|
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
|
|
|
|
# model = BertForPreTraining
|
|
|
|
# policy = get_autopolicy(model)
|
|
|
|
# print(policy)
|