import torch.nn as nn

from .basepolicy import Policy

def build_policies():
    r"""
    Build the mapping from supported model classes to their shard policies.

    Return:
        The dict mapping each supported model class to its policy class
    """
    auto_policy_dict = {}

    # BERT policies. Imports are deferred so that transformers is only
    # required when the mapping is actually built.
    from transformers import BertModel

    from .bert import BertModelPolicy
    auto_policy_dict[BertModel] = BertModelPolicy

    from transformers import BertForPreTraining

    from .bert import BertForPretrainingPolicy
    auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy

    from transformers import BertLMHeadModel

    from .bert import BertLMHeadModelPolicy
    auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy

    from transformers import BertForMaskedLM

    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers import BertForNextSentencePrediction

    from .bert import BertForNextSentencePredictionPolicy
    auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy

    from transformers import BertForSequenceClassification

    from .bert import BertForSequenceClassificationPolicy
    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

    # LLaMA and GPT-2 policies are not implemented yet, so the whole block
    # (including the LlamaModel import, which would otherwise be unused)
    # stays commented out.
    # from transformers.models.llama.modeling_llama import LlamaModel
    # from .llama import LlamaPolicy
    # auto_policy_dict[LlamaModel] = LlamaPolicy

    # from transformers import LlamaForSequenceClassification
    # from .llama import LlamaForSequenceClassificationPolicy
    # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy

    # from transformers import LlamaForCausalLM
    # from .llama import LlamaForCausalLMPolicy
    # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy

    # from transformers import GPT2Model
    # from .gpt2 import GPT2Policy
    # auto_policy_dict[GPT2Model] = GPT2Policy

    # from transformers import GPT2LMHeadModel
    # from .gpt2 import GPT2LMHeadModelPolicy
    # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

    return auto_policy_dict

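
# Illustrative sketch (assumes `transformers` is installed): the dict returned by
# build_policies() can be inspected to see which model classes are currently
# supported. The variable names below are hypothetical.
#
#   supported = build_policies()
#   for model_cls, policy_cls in supported.items():
#       print(f"{model_cls.__qualname__} -> {policy_cls.__qualname__}")
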
def get_autopolicy(model: nn.Module) -> Policy:
    r"""
    Return the auto policy for the model

    Args:
        model (:class:`nn.Module`): The model to get the auto policy for

    Return:
        :class:`Policy`: The auto policy for the model
    """
    auto_policy_dict = build_policies()
    policy = auto_policy_dict.get(model.__class__, None)
    if policy is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n"
            f"Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
        )
    return policy()

# Example usage (commented out): get_autopolicy expects a model *instance*,
# since it looks up model.__class__ in the policy dict.
# from transformers import BertConfig, BertForPreTraining
# model = BertForPreTraining(BertConfig())
# policy = get_autopolicy(model)
# print(policy)
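
# Illustrative sketch of the error path (hypothetical `model` variable):
# get_autopolicy raises NotImplementedError for model classes without a
# registered policy, and the message lists the supported classes, so callers
# can fall back explicitly.
#
#   try:
#       policy = get_autopolicy(model)
#   except NotImplementedError as e:
#       print(e)        # lists the supported model classes
#       policy = None   # e.g. fall back to running the model unsharded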