mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support bias_gelu_jit_fused for models (#5647)
* support gelu_bias_fused for gpt2 * support gelu_bias_fused for gpt2 fix fix fix * fix fix * fixpull/5541/head
parent
7f8b16635b
commit
6af6d6fc9f
|
@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bert_intermediate_forward():
|
||||
from transformers.models.bert.modeling_bert import BertIntermediate
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, bias = self.dense(hidden_states)
|
||||
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
|
|||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_blip2_mlp_forward():
|
||||
from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, bias = self.fc1(hidden_states)
|
||||
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_gpt2_mlp_forward():
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
||||
hidden_states, bias = self.c_fc(hidden_states)
|
||||
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
|
||||
hidden_states = self.c_proj(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@ -372,3 +372,15 @@ def get_jit_fused_vit_output_forward():
|
|||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_vit_intermediate_forward():
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, bias = self.dense(hidden_states)
|
||||
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
|
||||
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@ -12,6 +12,7 @@ from ..modeling.bert import (
|
|||
BertPipelineForwards,
|
||||
bert_sequence_parallel_forward_fn,
|
||||
get_bert_flash_attention_forward,
|
||||
get_jit_fused_bert_intermediate_forward,
|
||||
get_jit_fused_bert_output_forward,
|
||||
get_jit_fused_bert_self_output_forward,
|
||||
)
|
||||
|
@ -38,11 +39,13 @@ class BertPolicy(Policy):
|
|||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertEmbeddings,
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertModel,
|
||||
BertOutput,
|
||||
|
@ -131,6 +134,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -153,6 +157,14 @@ class BertPolicy(Policy):
|
|||
),
|
||||
]
|
||||
)
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_bert_intermediate_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BertIntermediate,
|
||||
)
|
||||
|
||||
if sp_mode == "split_gather":
|
||||
self.append_or_create_method_replacement(
|
||||
|
|
|
@ -3,6 +3,7 @@ import colossalai.shardformer.layer as col_nn
|
|||
from ..modeling.blip2 import (
|
||||
forward_fn,
|
||||
get_blip2_flash_attention_forward,
|
||||
get_jit_fused_blip2_mlp_forward,
|
||||
get_jit_fused_blip2_QFormer_output_forward,
|
||||
get_jit_fused_blip2_QFormer_self_output_forward,
|
||||
)
|
||||
|
@ -18,12 +19,16 @@ class BlipPolicy(Policy):
|
|||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.enable_bias_gelu_fused = (
|
||||
self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
|
||||
)
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.blip_2.modeling_blip_2 import (
|
||||
Blip2Attention,
|
||||
Blip2EncoderLayer,
|
||||
Blip2MLP,
|
||||
Blip2QFormerLayer,
|
||||
Blip2QFormerModel,
|
||||
Blip2QFormerOutput,
|
||||
|
@ -73,6 +78,7 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc2",
|
||||
|
@ -201,6 +207,14 @@ class BlipPolicy(Policy):
|
|||
)
|
||||
|
||||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_blip2_mlp_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Blip2MLP,
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
|
|
@ -10,6 +10,7 @@ from ..modeling.gpt2 import (
|
|||
GPT2PipelineForwards,
|
||||
get_gpt2_flash_attention_forward,
|
||||
get_gpt_model_forward_for_flash_attn,
|
||||
get_jit_fused_gpt2_mlp_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
gpt2_sequence_parallel_forward_fn,
|
||||
)
|
||||
|
@ -36,10 +37,13 @@ class GPT2Policy(Policy):
|
|||
"""
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
self.enable_bias_gelu_fused = (
|
||||
self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
|
||||
)
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": GPT2Attention,
|
||||
|
@ -119,6 +123,7 @@ class GPT2Policy(Policy):
|
|||
"n_fused": 1,
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -142,6 +147,14 @@ class GPT2Policy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_gpt2_mlp_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GPT2MLP,
|
||||
)
|
||||
if embedding_cls is not None:
|
||||
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..modeling.vit import (
|
|||
ViTForImageClassification_pipeline_forward,
|
||||
ViTForMaskedImageModeling_pipeline_forward,
|
||||
ViTModel_pipeline_forward,
|
||||
get_jit_fused_vit_intermediate_forward,
|
||||
get_jit_fused_vit_output_forward,
|
||||
get_vit_flash_self_attention_forward,
|
||||
)
|
||||
|
@ -24,10 +25,17 @@ class ViTPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
|
||||
from transformers.models.vit.modeling_vit import (
|
||||
ViTEmbeddings,
|
||||
ViTIntermediate,
|
||||
ViTLayer,
|
||||
ViTOutput,
|
||||
ViTSelfAttention,
|
||||
)
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@ -83,6 +91,9 @@ class ViTPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
|
@ -94,6 +105,14 @@ class ViTPolicy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_vit_intermediate_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=ViTIntermediate,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
|
@ -115,6 +134,7 @@ class ViTPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=ViTOutput,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def new_model_class(self):
|
||||
|
|
Loading…
Reference in New Issue