[shardformer] support bias_gelu_jit_fused for models (#5647)

* support gelu_bias_fused for gpt2

* support gelu_bias_fused for gpt2

fix

fix

fix

* fix

fix

* fix
pull/5541/head
flybird11111 2024-04-29 15:33:51 +08:00 committed by GitHub
parent 7f8b16635b
commit 6af6d6fc9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 115 additions and 2 deletions

View File

@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
)
return forward
def get_jit_fused_bert_intermediate_forward():
from transformers.models.bert.modeling_bert import BertIntermediate
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.dense(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
return hidden_states
return forward

View File

@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
return hidden_states
return forward
def get_jit_fused_blip2_mlp_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.fc1(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
hidden_states = self.fc2(hidden_states)
return hidden_states
return forward

View File

@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
return forward
def get_jit_fused_gpt2_mlp_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states, bias = self.c_fc(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
return forward

View File

@ -372,3 +372,15 @@ def get_jit_fused_vit_output_forward():
return hidden_states
return forward
def get_jit_fused_vit_intermediate_forward():
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.dense(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
return hidden_states
return forward

View File

@ -12,6 +12,7 @@ from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
)
@ -38,11 +39,13 @@ class BertPolicy(Policy):
def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertIntermediate,
BertLayer,
BertModel,
BertOutput,
@ -131,6 +134,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
@ -153,6 +157,14 @@ class BertPolicy(Policy):
),
]
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bert_intermediate_forward(),
},
policy=policy,
target_key=BertIntermediate,
)
if sp_mode == "split_gather":
self.append_or_create_method_replacement(

View File

@ -3,6 +3,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
get_jit_fused_blip2_mlp_forward,
get_jit_fused_blip2_QFormer_output_forward,
get_jit_fused_blip2_QFormer_self_output_forward,
)
@ -18,12 +19,16 @@ class BlipPolicy(Policy):
def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.enable_bias_gelu_fused = (
self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
)
return self.model
def module_policy(self):
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Attention,
Blip2EncoderLayer,
Blip2MLP,
Blip2QFormerLayer,
Blip2QFormerModel,
Blip2QFormerOutput,
@ -73,6 +78,7 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
@ -201,6 +207,14 @@ class BlipPolicy(Policy):
)
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_blip2_mlp_forward(),
},
policy=policy,
target_key=Blip2MLP,
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(

View File

@ -10,6 +10,7 @@ from ..modeling.gpt2 import (
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
get_jit_fused_gpt2_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
@ -36,10 +37,13 @@ class GPT2Policy(Policy):
"""
self.tie_weight = self.tie_weight_check()
self.origin_attn_implement = self.model.config._attn_implementation
self.enable_bias_gelu_fused = (
self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
)
return self.model
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
ATTN_IMPLEMENTATION = {
"eager": GPT2Attention,
@ -119,6 +123,7 @@ class GPT2Policy(Policy):
"n_fused": 1,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
@ -142,6 +147,14 @@ class GPT2Policy(Policy):
),
],
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_gpt2_mlp_forward(),
},
policy=policy,
target_key=GPT2MLP,
)
if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(

View File

@ -11,6 +11,7 @@ from ..modeling.vit import (
ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward,
get_jit_fused_vit_intermediate_forward,
get_jit_fused_vit_output_forward,
get_vit_flash_self_attention_forward,
)
@ -24,10 +25,17 @@ class ViTPolicy(Policy):
pass
def preprocess(self):
self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
from transformers.models.vit.modeling_vit import (
ViTEmbeddings,
ViTIntermediate,
ViTLayer,
ViTOutput,
ViTSelfAttention,
)
policy = {}
@ -83,6 +91,9 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
@ -94,6 +105,14 @@ class ViTPolicy(Policy):
),
],
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_vit_intermediate_forward(),
},
policy=policy,
target_key=ViTIntermediate,
)
# use flash attention
if self.shard_config.enable_flash_attention:
@ -115,6 +134,7 @@ class ViTPolicy(Policy):
policy=policy,
target_key=ViTOutput,
)
return policy
def new_model_class(self):