mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] import huggingface implicitly (#4101)
parent
6a88bae4ec
commit
44a190e6ac
|
@ -5,6 +5,8 @@ import torch.nn as nn
|
|||
|
||||
from .basepolicy import Policy
|
||||
|
||||
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyLocation:
|
||||
|
|
|
@ -8,6 +8,8 @@ import torch.nn as nn
|
|||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
|
||||
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
|
||||
|
||||
|
||||
class ParallelModule():
|
||||
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertEmbeddings,
|
||||
BertForMultipleChoice,
|
||||
BertForSequenceClassification,
|
||||
BertForTokenClassification,
|
||||
BertLayer,
|
||||
BertLMPredictionHead,
|
||||
)
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
|
||||
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
|
||||
'BertForMultipleChoicePolicy'
|
||||
]
|
||||
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
|
@ -33,6 +31,8 @@ class BertPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
|
||||
|
||||
base_policy = {
|
||||
BertLayer:
|
||||
ModulePolicyDescription(
|
||||
|
@ -123,7 +123,7 @@ class BertPolicy(Policy):
|
|||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return self.model
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
|
@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
|
@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
|
@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForSequenceClassification
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForSequenceClassification:
|
||||
|
@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForTokenClassification
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForTokenClassification:
|
||||
|
@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForMultipleChoice
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForMultipleChoice:
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
import torch.nn as nn
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
|
||||
'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
|
||||
]
|
||||
|
||||
|
||||
class GPT2Policy(Policy):
|
||||
|
||||
|
@ -25,7 +29,9 @@ class GPT2Policy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
base_policy = {
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
|
||||
|
||||
return {
|
||||
GPT2Model:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
GPT2LMHeadModel:
|
||||
|
@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
GPT2DoubleHeadsModel:
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
|
||||
|
||||
|
||||
class LlamaPolicy(Policy):
|
||||
|
||||
|
@ -26,7 +26,9 @@ class LlamaPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
base_policy = {
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
|
||||
return {
|
||||
LlamaDecoderLayer:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
|
@ -109,6 +111,8 @@ class LlamaPolicy(Policy):
|
|||
class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
|
@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import LlamaForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
# add a new item for sequence classification
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
from transformers.models.opt.modeling_opt import (
|
||||
OPTAttention,
|
||||
OPTDecoder,
|
||||
OPTDecoderLayer,
|
||||
OPTForCausalLM,
|
||||
OPTForSequenceClassification,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
|
||||
'OPTForQuestionAnsweringPolicy'
|
||||
]
|
||||
|
||||
|
||||
class OPTPolicy(Policy):
|
||||
|
||||
|
@ -29,6 +26,8 @@ class OPTPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
|
||||
|
||||
base_policy = {
|
||||
OPTDecoder:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
|
@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy):
|
|||
class OPTForCausalLMPolicy(OPTPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
new_item = {
|
||||
OPTForCausalLM:
|
||||
|
|
|
@ -1,15 +1,4 @@
|
|||
from transformers import T5ForConditionalGeneration
|
||||
from transformers.models.t5.modeling_t5 import (
|
||||
T5Attention,
|
||||
T5DenseActDense,
|
||||
T5DenseGatedActDense,
|
||||
T5LayerCrossAttention,
|
||||
T5LayerFF,
|
||||
T5LayerSelfAttention,
|
||||
T5Stack,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
@ -34,7 +23,17 @@ class T5ModelPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
base_policy = {
|
||||
from transformers.models.t5.modeling_t5 import (
|
||||
T5Attention,
|
||||
T5DenseActDense,
|
||||
T5DenseGatedActDense,
|
||||
T5LayerCrossAttention,
|
||||
T5LayerFF,
|
||||
T5LayerSelfAttention,
|
||||
T5Stack,
|
||||
)
|
||||
|
||||
return {
|
||||
T5Stack:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
@ -165,6 +164,8 @@ class T5ModelPolicy(Policy):
|
|||
class T5ForConditionalGenerationPolicy(T5ModelPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
new_item = {
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['ViTPolicy']
|
||||
|
||||
|
||||
class ViTPolicy(Policy):
|
||||
|
||||
|
@ -25,7 +26,9 @@ class ViTPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
base_policy = {
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
|
||||
|
||||
return {
|
||||
ViTEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
|
|
@ -19,6 +19,7 @@ class ShardConfig:
|
|||
"""
|
||||
tensor_parallel_process_group: int = None
|
||||
enable_fused_normalization: bool = False
|
||||
enable_all_optimization: bool = False
|
||||
|
||||
# TODO: add support for tensor parallel
|
||||
# pipeline_parallel_size: int
|
||||
|
@ -27,6 +28,21 @@ class ShardConfig:
|
|||
# inference_only: bool = True
|
||||
# gather_output: bool = True
|
||||
|
||||
@property
|
||||
def tensor_parallel_size(self):
|
||||
return self._tensor_parallel_size
|
||||
|
||||
def __post_init__(self):
|
||||
# get the parallel size
|
||||
self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
|
||||
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
|
||||
|
||||
# turn on all optimization if all_optimization is set to True
|
||||
if self.enable_all_optimization:
|
||||
self._turn_on_all_optimization()
|
||||
|
||||
def _turn_on_all_optimization(self):
|
||||
"""
|
||||
Turn on all optimization.
|
||||
"""
|
||||
# you can add all the optimization flag here
|
||||
self.fused_layernorm = True
|
||||
|
|
Loading…
Reference in New Issue