[shardformer] import huggingface implicitly (#4101)

pull/4157/head
Frank Lee 2023-06-30 10:56:29 +08:00
parent 6a88bae4ec
commit 44a190e6ac
9 changed files with 91 additions and 38 deletions
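For reference, the change repeated across the policy files below follows one pattern: every module-level transformers import moves inside module_policy(), so transformers is only imported once a policy is actually queried and importing the policy modules themselves no longer requires it. A minimal, hypothetical sketch of the idea (SimpleModulePolicyDescription and BertPolicySketch are illustrative stand-ins, not the real classes from basepolicy.py and bert.py):

from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class SimpleModulePolicyDescription:
    # simplified stand-in for ModulePolicyDescription in basepolicy.py
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    param_replacement: List[Any] = field(default_factory=list)
    sub_module_replacement: List[Any] = field(default_factory=list)


class BertPolicySketch:

    def module_policy(self):
        # deferred import: transformers is loaded only when the policy is used
        from transformers.models.bert.modeling_bert import BertLayer

        return {
            BertLayer: SimpleModulePolicyDescription(attribute_replacement={},
                                                     param_replacement=[]),
        }

The real policies build the same kind of mapping, with the full attribute, parameter, and sub-module replacements shown in the diffs below.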


@@ -5,6 +5,8 @@ import torch.nn as nn
from .basepolicy import Policy
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@dataclass
class PolicyLocation:


@@ -8,6 +8,8 @@ import torch.nn as nn
from ..shard.shard_config import ShardConfig
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
class ParallelModule():


@@ -1,18 +1,16 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertForMultipleChoice,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMPredictionHead,
)
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy'
]
class BertPolicy(Policy):
@@ -33,6 +31,8 @@ class BertPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
base_policy = {
BertLayer:
ModulePolicyDescription(
@@ -123,7 +123,7 @@ class BertPolicy(Policy):
def new_model_class(self):
# do nothing
return self.model
return None
def postprocess(self):
return self.model
@@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
@@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
@@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:


@@ -1,11 +1,15 @@
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
]
class GPT2Policy(Policy):
@@ -25,7 +29,9 @@ class GPT2Policy(Policy):
return self.model
def module_policy(self):
base_policy = {
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
return {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
@@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:


@@ -1,13 +1,13 @@
from typing import Dict, Union
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
class LlamaPolicy(Policy):
@@ -26,7 +26,9 @@ class LlamaPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
base_policy = {
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
return {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -109,6 +111,8 @@ class LlamaPolicy(Policy):
class LlamaForCausalLMPolicy(LlamaPolicy):
def module_policy(self):
from transformers import LlamaForCausalLM
policy = super().module_policy()
# add a new item for causal lm
new_item = {
@@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
def module_policy(self):
from transformers import LlamaForSequenceClassification
policy = super().module_policy()
# add a new item for sequence classification


@@ -1,15 +1,12 @@
from transformers.models.opt.modeling_opt import (
OPTAttention,
OPTDecoder,
OPTDecoderLayer,
OPTForCausalLM,
OPTForSequenceClassification,
)
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
'OPTForQuestionAnsweringPolicy'
]
class OPTPolicy(Policy):
@@ -29,6 +26,8 @@ class OPTPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
base_policy = {
OPTDecoder:
ModulePolicyDescription(attribute_replacement={},
@@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy):
class OPTForCausalLMPolicy(OPTPolicy):
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy()
new_item = {
OPTForCausalLM:


@@ -1,15 +1,4 @@
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -34,7 +23,17 @@ class T5ModelPolicy(Policy):
return self.model
def module_policy(self):
base_policy = {
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)
return {
T5Stack:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -165,6 +164,8 @@ class T5ModelPolicy(Policy):
class T5ForConditionalGenerationPolicy(T5ModelPolicy):
def module_policy(self):
from transformers import T5ForConditionalGeneration
policy = super().module_policy()
new_item = {


@@ -1,12 +1,13 @@
from typing import Dict, Union
import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ViTPolicy']
class ViTPolicy(Policy):
@@ -25,7 +26,9 @@ class ViTPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
base_policy = {
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
return {
ViTEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],


@@ -19,6 +19,7 @@ class ShardConfig:
"""
tensor_parallel_process_group: int = None
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
# TODO: add support for tensor parallel
# pipeline_parallel_size: int
@@ -27,6 +28,21 @@
# inference_only: bool = True
# gather_output: bool = True
@property
def tensor_parallel_size(self):
return self._tensor_parallel_size
def __post_init__(self):
# get the parallel size
self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()
def _turn_on_all_optimization(self):
"""
Turn on all optimization.
"""
# you can add all the optimization flag here
self.fused_layernorm = True
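
Taken together, ShardConfig now derives its tensor parallel size from the process group in __post_init__, exposes it as a read-only tensor_parallel_size property, and the new enable_all_optimization flag turns on the individual optimization switches from _turn_on_all_optimization. A hypothetical single-process usage sketch (the gloo backend, address, and port are illustrative; a process group must be initialized before constructing the config, because __post_init__ calls dist.get_world_size):

import torch.distributed as dist

from colossalai.shardformer.shard.shard_config import ShardConfig

if __name__ == "__main__":
    # throwaway single-rank group so dist.get_world_size() has something to query
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)

    config = ShardConfig(enable_all_optimization=True)

    # read-only property backed by _tensor_parallel_size, set in __post_init__
    print(config.tensor_parallel_size)    # -> 1 in this single-process run

    dist.destroy_process_group()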