ColossalAI/colossalai/shardformer/policies/auto_policy.py

225 lines
9.7 KiB
Python

import importlib
from dataclasses import dataclass
from typing import Optional
import torch.nn as nn
from ..shard.shard_config import ShardConfig
from .base_policy import Policy
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@dataclass
class PolicyLocation:
"""
PolicyLocation describes the location of a policy class.
Args:
file_name (str): The file name of the policy under colossalai.shardformer.policies
class_name (str): The class name of the policy class
"""
file_name: str
class_name: str
# we don't want to import all policies here
# as each policy file imports its own model zoo library
# we will allow the user to only import the policy file needed
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation(
file_name="bert", class_name="BertForPreTrainingPolicy"
),
"transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(
file_name="bert", class_name="BertLMHeadModelPolicy"
),
"transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation(
file_name="bert", class_name="BertForMaskedLMPolicy"
),
"transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation(
file_name="bert", class_name="BertForSequenceClassificationPolicy"
),
"transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation(
file_name="bert", class_name="BertForTokenClassificationPolicy"
),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation(
file_name="bert", class_name="BertForNextSentencePredictionPolicy"
),
"transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(
file_name="bert", class_name="BertForMultipleChoicePolicy"
),
"transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation(
file_name="bert", class_name="BertForQuestionAnsweringPolicy"
),
# LLaMA
"transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
file_name="llama", class_name="LlamaModelPolicy"
),
"transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
file_name="llama", class_name="LlamaForCausalLMPolicy"
),
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation(
file_name="llama", class_name="LlamaForSequenceClassificationPolicy"
),
# T5
"transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation(
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(
file_name="gpt2", class_name="GPT2LMHeadModelPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation(
file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation(
file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation(
file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(
file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"
),
# ViT
"transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
"transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(
file_name="vit", class_name="ViTForImageClassificationPolicy"
),
"transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation(
file_name="vit", class_name="ViTForMaskedImageModelingPolicy"
),
# OPT
"transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
"transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation(
file_name="opt", class_name="OPTForCausalLMPolicy"
),
"transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation(
file_name="opt", class_name="OPTForSequenceClassificationPolicy"
),
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(
file_name="opt", class_name="OPTForQuestionAnsweringPolicy"
),
# Bloom
"transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
file_name="bloom", class_name="BloomModelPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
file_name="bloom", class_name="BloomForCausalLMPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation(
file_name="bloom", class_name="BloomForSequenceClassificationPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation(
file_name="bloom", class_name="BloomForTokenClassificationPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(
file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"
),
# Whisper
"transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation(
file_name="whisper", class_name="WhisperModelPolicy"
),
"transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation(
file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"
),
"transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation(
file_name="whisper", class_name="WhisperForAudioClassificationPolicy"
),
# Sam
"transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
# Blip2
"transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation(
file_name="blip2", class_name="Blip2ModelPolicy"
),
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation(
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
),
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMModelPolicy"
),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
}
_INFER_POLICY_LIST = {
# LlaMa
"transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
file_name="llama", class_name="LlamaModelInferPolicy"
),
"transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
file_name="llama", class_name="LlamaModelInferPolicy"
),
# Bloom
"transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
file_name="bloom", class_name="BloomModelInferPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
file_name="bloom", class_name="BloomModelInferPolicy"
),
# ChatGLM2
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLM2InferPolicy"
),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
),
}
def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
if inference_only:
module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
else:
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
def _fullname(obj):
"""
Return the full name of an object, including the module name.
"""
klass = obj.__class__
module = klass.__module__
if module == "builtins":
return klass.__qualname__ # avoid outputs like 'builtins.str'
return module + "." + klass.__qualname__
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
r"""
Return the auto policy for the model
Args:
model (:class:`nn.Module`): The model to get the auto policy
Return:
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
inference_only = shard_config.extra_kwargs.get("inference_only", False)
if inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location, inference_only)
return policy()