2023-06-19 05:53:17 +00:00
import importlib
from dataclasses import dataclass
from typing import Optional, Type

import torch.nn as nn

from ..shard.shard_config import ShardConfig
from .base_policy import Policy
2023-06-15 09:55:42 +00:00
2023-06-30 02:56:29 +00:00
__all__ = [ " PolicyLocation " , " get_autopolicy " , " import_policy " ]
2023-05-24 02:26:46 +00:00
2023-06-19 05:53:17 +00:00
@dataclass
class PolicyLocation:
    """
    PolicyLocation describes where a policy class lives, so that it can be imported on demand.

    Args:
        file_name (str): The module name of the policy file (without the ``.py`` suffix). By default it is
            resolved under ``colossalai.shardformer.policies``; for inference-only lookups it is resolved
            under ``colossalai.inference.tensor_parallel.policies`` instead.
        class_name (str): The name of the policy class defined inside that module.
    """

    file_name: str
    class_name: str
# Registry of sharding policies, keyed by the fully-qualified model class name.
# Keys must match what `_fullname(model)` returns, i.e. "<module>.<class qualname>".
# Values only record where each policy class lives; the policy module is imported lazily by
# `import_policy`, because each policy file imports its own model-zoo library and users should
# only pay for the policy they actually need.
# ChatGLM models are matched by their class path inside colossalai rather than transformers.
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
    "transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation(
        file_name="bert", class_name="BertForPreTrainingPolicy"
    ),
    "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(
        file_name="bert", class_name="BertLMHeadModelPolicy"
    ),
    "transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation(
        file_name="bert", class_name="BertForMaskedLMPolicy"
    ),
    "transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation(
        file_name="bert", class_name="BertForSequenceClassificationPolicy"
    ),
    "transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation(
        file_name="bert", class_name="BertForTokenClassificationPolicy"
    ),
    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation(
        file_name="bert", class_name="BertForNextSentencePredictionPolicy"
    ),
    "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(
        file_name="bert", class_name="BertForMultipleChoicePolicy"
    ),
    "transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation(
        file_name="bert", class_name="BertForQuestionAnsweringPolicy"
    ),
    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
        file_name="llama", class_name="LlamaModelPolicy"
    ),
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
        file_name="llama", class_name="LlamaForCausalLMPolicy"
    ),
    "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation(
        file_name="llama", class_name="LlamaForSequenceClassificationPolicy"
    ),
    # T5
    "transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
    "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation(
        file_name="t5", class_name="T5ForConditionalGenerationPolicy"
    ),
    "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
    # GPT2
    "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
    "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(
        file_name="gpt2", class_name="GPT2LMHeadModelPolicy"
    ),
    "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation(
        file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"
    ),
    "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation(
        file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"
    ),
    "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation(
        file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"
    ),
    "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(
        file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"
    ),
    # ViT
    "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
    "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(
        file_name="vit", class_name="ViTForImageClassificationPolicy"
    ),
    "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation(
        file_name="vit", class_name="ViTForMaskedImageModelingPolicy"
    ),
    # OPT
    "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
    "transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation(
        file_name="opt", class_name="OPTForCausalLMPolicy"
    ),
    "transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation(
        file_name="opt", class_name="OPTForSequenceClassificationPolicy"
    ),
    "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(
        file_name="opt", class_name="OPTForQuestionAnsweringPolicy"
    ),
    # Bloom
    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
        file_name="bloom", class_name="BloomModelPolicy"
    ),
    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
        file_name="bloom", class_name="BloomForCausalLMPolicy"
    ),
    "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation(
        file_name="bloom", class_name="BloomForSequenceClassificationPolicy"
    ),
    "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation(
        file_name="bloom", class_name="BloomForTokenClassificationPolicy"
    ),
    "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(
        file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"
    ),
    # Whisper
    "transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation(
        file_name="whisper", class_name="WhisperModelPolicy"
    ),
    "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation(
        file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"
    ),
    "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation(
        file_name="whisper", class_name="WhisperForAudioClassificationPolicy"
    ),
    # Sam
    "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
    # Blip2
    "transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation(
        file_name="blip2", class_name="Blip2ModelPolicy"
    ),
    "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation(
        file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
    ),
    # ChatGLM
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLMModelPolicy"
    ),
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
    ),
}
2023-09-11 17:22:56 +00:00
# Registry of inference-only policies, consulted instead of `_POLICY_LIST` when
# `shard_config.extra_kwargs["inference_only"]` is truthy. Keys use the same
# "<module>.<class qualname>" format as `_POLICY_LIST`; each `file_name` is resolved under
# `colossalai.inference.tensor_parallel.policies` by `import_policy`.
# Note that the base model and the causal-LM head of LLaMA and Bloom share one policy class.
_INFER_POLICY_LIST = {
    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
        file_name="llama", class_name="LlamaModelInferPolicy"
    ),
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
        file_name="llama", class_name="LlamaModelInferPolicy"
    ),
    # Bloom
    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
        file_name="bloom", class_name="BloomModelInferPolicy"
    ),
    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
        file_name="bloom", class_name="BloomModelInferPolicy"
    ),
    # ChatGLM2
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
    ),
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
    ),
}
2023-06-19 05:53:17 +00:00
2023-09-11 17:22:56 +00:00
def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Type[Policy]:
    """
    Dynamically import a policy class based on the policy location.

    Args:
        policy_location (PolicyLocation): The module file name and class name of the policy.
        inference_only (bool, optional): If truthy, the module is looked up under
            ``colossalai.inference.tensor_parallel.policies``; otherwise under
            ``colossalai.shardformer.policies``. Defaults to False.

    Returns:
        The policy *class* itself (not an instance); the caller is responsible for instantiating it.

    Raises:
        ModuleNotFoundError: If the policy module cannot be found.
        AttributeError: If the module does not define ``policy_location.class_name``.
    """
    package = (
        "colossalai.inference.tensor_parallel.policies" if inference_only else "colossalai.shardformer.policies"
    )
    module = importlib.import_module(f"{package}.{policy_location.file_name}")
    return getattr(module, policy_location.class_name)
2023-06-13 06:44:40 +00:00
2023-06-07 08:09:40 +00:00
2023-06-19 05:53:17 +00:00
def _fullname(obj):
    """
    Build the dotted ``<module>.<qualname>`` path of ``obj``'s class.

    Classes that live in ``builtins`` are reported by their bare qualname
    (``"str"`` instead of ``"builtins.str"``).
    """
    cls = obj.__class__
    module_name = cls.__module__
    qualname = cls.__qualname__
    if module_name != "builtins":
        return module_name + "." + qualname
    return qualname
2023-05-22 07:02:17 +00:00
2023-05-24 02:26:46 +00:00
2023-11-03 05:32:43 +00:00
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
    r"""
    Return the auto policy for the model.

    Args:
        model (:class:`nn.Module`): The model to get the auto policy for.
        shard_config (:class:`ShardConfig`, optional): The shard configuration. If given and
            ``shard_config.extra_kwargs["inference_only"]`` is truthy, the inference policy registry
            is used; otherwise (including when ``shard_config`` is ``None``) the default sharding
            policy registry is used.

    Return:
        :class:`Policy`: An instance of the auto policy for the model.

    Raises:
        NotImplementedError: If no policy is registered for the model's class in the selected registry.
    """
    full_name = _fullname(model)

    # shard_config defaults to None, so guard before reading its extra kwargs.
    inference_only = False
    if shard_config is not None:
        inference_only = shard_config.extra_kwargs.get("inference_only", False)

    policy_list = _INFER_POLICY_LIST if inference_only else _POLICY_LIST
    policy_location = policy_list.get(full_name, None)
    if policy_location is None:
        # Only list the registry that was actually consulted, so the hint is accurate for this mode.
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented.\n"
            f"Supported models are {list(policy_list.keys())}"
        )

    policy = import_policy(policy_location, inference_only)
    return policy()