@ -1,10 +1,8 @@
import importlib
from dataclasses import dataclass
from typing import Optional
import torch . nn as nn
from . . shard . shard_config import ShardConfig
from . base_policy import Policy
__all__ = [ " PolicyLocation " , " get_autopolicy " , " import_policy " ]
@ -150,39 +148,12 @@ _POLICY_LIST = {
) ,
}
# Maps the fully-qualified class name of a model to the location of the
# policy used for it when sharding is requested in inference-only mode.
# Keys must match the dotted "<module>.<qualname>" string exactly, so the
# literals carry no surrounding whitespace. ``file_name`` is the policy
# module's name and ``class_name`` is the attribute looked up inside it.
_INFER_POLICY_LIST = {
    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
        file_name="llama", class_name="LlamaModelInferPolicy"
    ),
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
        file_name="llama", class_name="LlamaModelInferPolicy"
    ),
    # Bloom
    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
        file_name="bloom", class_name="BloomModelInferPolicy"
    ),
    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
        file_name="bloom", class_name="BloomModelInferPolicy"
    ),
    # ChatGLM2
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
    ),
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
    ),
}
def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
    """
    Dynamically import a Policy class based on the policy location.

    Args:
        policy_location (PolicyLocation): Supplies ``file_name`` (the policy
            module to import) and ``class_name`` (the attribute to fetch
            from that module).
        inference_only (bool, optional): When truthy, the module is imported
            from ``colossalai.inference.tensor_parallel.policies``; otherwise
            from ``colossalai.shardformer.policies``. Defaults to False, so
            single-argument calls keep resolving shardformer policies.

    Returns:
        The attribute named ``class_name`` of the imported module. The
        object is returned as-is and is not instantiated here.

    Raises:
        ModuleNotFoundError: If the computed module cannot be imported.
        AttributeError: If the module has no attribute ``class_name``.
    """
    # Pick the package exactly once; the flag must not be overridden by a
    # later unconditional assignment, or inference policies become unreachable.
    if inference_only:
        package = "colossalai.inference.tensor_parallel.policies"
    else:
        package = "colossalai.shardformer.policies"
    module_name = f"{package}.{policy_location.file_name}"
    module = importlib.import_module(module_name)
    return getattr(module, policy_location.class_name)
@ -198,7 +169,7 @@ def _fullname(obj):
return module + " . " + klass . __qualname__
def get_autopolicy(model: nn.Module, shard_config: Optional[ShardConfig] = None) -> Policy:
    r"""
    Return the auto policy for the model

    Args:
        model (:class:`nn.Module`): The model whose fully-qualified class
            name is used to look up a policy.
        shard_config (:class:`ShardConfig`, optional): When given, its
            ``extra_kwargs["inference_only"]`` entry selects the inference
            policy table instead of the regular one. May be omitted.

    Return:
        :class:`Policy`: The auto policy for the model

    Raises:
        NotImplementedError: If no policy is registered for the model's class
            in the selected table.
    """
    full_name = _fullname(model)

    # shard_config defaults to None, so it must not be dereferenced blindly;
    # a missing config (or missing extra_kwargs) simply means "not inference".
    extra_kwargs = getattr(shard_config, "extra_kwargs", None) or {}
    inference_only = extra_kwargs.get("inference_only", None)

    # Look the model up in exactly one table, chosen by the flag.
    policy_table = _INFER_POLICY_LIST if inference_only else _POLICY_LIST
    policy_location = policy_table.get(full_name, None)

    if policy_location is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
        )

    # import_policy returns the policy class; instantiate it for the caller.
    policy = import_policy(policy_location, inference_only)
    return policy()