mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support bias_gelu_jit_fused for models (#5647)
* support gelu_bias_fused for gpt2
* support gelu_bias_fused for gpt2 fix fix fix
* fix fix
* fix

branch: pull/5541/head
parent 7f8b16635b
commit 6af6d6fc9f

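The first four hunks add fused bias+GeLU forwards to the shardformer modeling code for BERT, BLIP-2, GPT-2 and ViT; the remaining hunks wire them into the matching policies. All of the new forwards route the intermediate activation through colossalai.kernel.jit.bias_gelu.GeLUFunction, which folds the bias addition into the GeLU so the two run as fused TorchScript kernels instead of separate elementwise passes. As a rough illustration only (the actual kernels live in colossalai/kernel/jit/bias_gelu.py and may be formulated differently), a Megatron-style fused bias+GeLU autograd function looks like this:

# Illustrative sketch only: the real kernels live in colossalai/kernel/jit/bias_gelu.py
# (imported in the diff below as GeLUFunction); the exact formulation there may differ.
import torch


@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # One scripted kernel for bias add + tanh-approximated GeLU.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


@torch.jit.script
def bias_gelu_back(g: torch.Tensor, bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Fused backward: derivative of the tanh-approximated GeLU at x = bias + y, times the upstream grad.
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    ff = 0.5 * x * ((1.0 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1.0 + tanh_out)
    return ff * g


class FusedBiasGeLU(torch.autograd.Function):
    # Hypothetical stand-in for colossalai.kernel.jit.bias_gelu.GeLUFunction.
    @staticmethod
    def forward(ctx, input: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        # Gradients w.r.t. the matmul output and the bias are identical up to broadcasting.
        return tmp, tmp
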
@@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_jit_fused_bert_intermediate_forward():
+    from transformers.models.bert.modeling_bert import BertIntermediate
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        return hidden_states
+
+    return forward

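The rewritten BertIntermediate.forward only works because the policy further down replaces intermediate.dense with a Linear1D_Col built with skip_bias_add=True, so self.dense(hidden_states) hands back the un-biased matmul result together with the bias tensor. A minimal check of that contract, with plain functional ops standing in for the sharded linear and the JIT kernel:

# Sketch with plain torch ops; Linear1D_Col and the JIT kernel are replaced by
# F.linear and F.gelu(..., approximate="tanh") purely to show the equivalence.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 16, 768)
weight = torch.randn(3072, 768) * 0.02
bias = torch.randn(3072) * 0.02

# skip_bias_add path: matmul without bias, bias handed back separately
out_no_bias = F.linear(x, weight)
fused_style = F.gelu(out_no_bias + bias, approximate="tanh")

# ordinary path: Linear (bias included) followed by GeLU
baseline = F.gelu(F.linear(x, weight, bias), approximate="tanh")

print(torch.allclose(fused_style, baseline, atol=1e-6))  # True
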
@@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
         return hidden_states
 
     return forward
+
+
+def get_jit_fused_blip2_mlp_forward():
+    from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.fc1(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+    return forward

@@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states, bias = self.c_fc(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+    return forward

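Since the bias is now consumed inside the autograd function instead of inside the linear layer, gradients have to reach both the matmul output and the separated bias. A quick sanity check against an unfused reference (assumes ColossalAI is installed; the reference below uses the usual Megatron-style tanh approximation, so if the library kernel differs the tolerances may need loosening):

import torch

from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

# Same-shaped tensors keep the check simple; in the models above the bias is 1-D and broadcast.
x = torch.randn(4, 16, dtype=torch.double, requires_grad=True)
b = torch.randn(4, 16, dtype=torch.double, requires_grad=True)

JitGeLUFunction.apply(x, b).sum().backward()
gx_fused, gb_fused = x.grad.clone(), b.grad.clone()
x.grad, b.grad = None, None

# Unfused reference: explicit bias add followed by the tanh-approximated GeLU.
y = x + b
(y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))).sum().backward()

print(torch.allclose(gx_fused, x.grad), torch.allclose(gb_fused, b.grad))
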
@@ -372,3 +372,15 @@ def get_jit_fused_vit_output_forward():
         return hidden_states
 
     return forward
+
+
+def get_jit_fused_vit_intermediate_forward():
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+
+        return hidden_states
+
+    return forward

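All four replaced forwards call the same JitGeLUFunction, so the expected win is one fewer elementwise kernel launch per MLP block. A rough way to measure it (assumes a CUDA device and an installed ColossalAI; the first scripted calls include JIT compilation, hence the warm-up loop):

import torch
import torch.nn.functional as F

from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction


def time_ms(fn, iters: int = 100) -> float:
    # CUDA-event timing averaged over `iters` launches.
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


x = torch.randn(8, 1024, 4096, device="cuda")
bias = torch.randn(4096, device="cuda")

for _ in range(10):  # warm-up, including TorchScript compilation
    JitGeLUFunction.apply(x, bias)
    F.gelu(x + bias, approximate="tanh")

print("unfused:", time_ms(lambda: F.gelu(x + bias, approximate="tanh")), "ms")
print("fused:  ", time_ms(lambda: JitGeLUFunction.apply(x, bias)), "ms")
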
@@ -12,6 +12,7 @@ from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
+    get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
 )

@@ -38,11 +39,13 @@ class BertPolicy(Policy):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model
 
     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
+            BertIntermediate,
             BertLayer,
             BertModel,
             BertOutput,

@@ -131,6 +134,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -153,6 +157,14 @@ class BertPolicy(Policy):
                     ),
                 ]
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_bert_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=BertIntermediate,
+                )
 
         if sp_mode == "split_gather":
             self.append_or_create_method_replacement(

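Nothing new is exposed to users: the fused path keys off the existing enable_jit_fused flag that the policy reads from self.shard_config, and it only turns on when the model config uses plain "gelu". A sketch of how that flag is passed in; the ShardConfig/ShardFormer calls below are assumptions about the surrounding API (not part of this diff), and a distributed context must already be initialized:

import colossalai
from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch({})  # assumes `torchrun`; newer versions take no config argument

model = BertModel(BertConfig(hidden_act="gelu"))  # fusion is gated on hidden_act == "gelu"

# Tensor parallelism stays enabled so intermediate.dense really becomes a
# Linear1D_Col(skip_bias_add=True); enable_jit_fused is what the new policy code reads.
shard_config = ShardConfig(enable_jit_fused=True)
sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)
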
@@ -3,6 +3,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.blip2 import (
     forward_fn,
     get_blip2_flash_attention_forward,
+    get_jit_fused_blip2_mlp_forward,
     get_jit_fused_blip2_QFormer_output_forward,
     get_jit_fused_blip2_QFormer_self_output_forward,
 )

@@ -18,12 +19,16 @@ class BlipPolicy(Policy):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = (
+            self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
+        )
         return self.model
 
     def module_policy(self):
         from transformers.models.blip_2.modeling_blip_2 import (
             Blip2Attention,
             Blip2EncoderLayer,
+            Blip2MLP,
             Blip2QFormerLayer,
             Blip2QFormerModel,
             Blip2QFormerOutput,

@@ -73,6 +78,7 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="mlp.fc1",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.fc2",

@@ -201,6 +207,14 @@ class BlipPolicy(Policy):
         )
 
         policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_blip2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=Blip2MLP,
+            )
 
         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(

@@ -10,6 +10,7 @@ from ..modeling.gpt2 import (
     GPT2PipelineForwards,
     get_gpt2_flash_attention_forward,
     get_gpt_model_forward_for_flash_attn,
+    get_jit_fused_gpt2_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
     gpt2_sequence_parallel_forward_fn,
 )

@@ -36,10 +37,13 @@ class GPT2Policy(Policy):
         """
         self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
+        self.enable_bias_gelu_fused = (
+            self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
+        )
         return self.model
 
     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
         ATTN_IMPLEMENTATION = {
             "eager": GPT2Attention,

@@ -119,6 +123,7 @@ class GPT2Policy(Policy):
                         "n_fused": 1,
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "skip_bias_add": self.enable_bias_gelu_fused,
                     },
                 ),
                 SubModuleReplacementDescription(

@@ -142,6 +147,14 @@ class GPT2Policy(Policy):
                     ),
                 ],
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_gpt2_mlp_forward(),
+                    },
+                    policy=policy,
+                    target_key=GPT2MLP,
+                )
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
             self.append_or_create_submodule_replacement(

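The gate is the literal string "gelu": BERT and ViT check config.hidden_act, GPT-2 checks config.activation_function, and BLIP-2 checks config.vision_config.hidden_act. Checkpoints that use a GeLU variant keep the stock forward, which is easy to confirm up front:

from transformers import AutoConfig

# "gelu" -> the fused bias+GeLU forward is eligible
print(AutoConfig.from_pretrained("bert-base-uncased").hidden_act)   # gelu

# "gelu_new" -> enable_bias_gelu_fused stays False, the stock forward is kept
print(AutoConfig.from_pretrained("gpt2").activation_function)       # gelu_new
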
@@ -11,6 +11,7 @@ from ..modeling.vit import (
     ViTForImageClassification_pipeline_forward,
     ViTForMaskedImageModeling_pipeline_forward,
     ViTModel_pipeline_forward,
+    get_jit_fused_vit_intermediate_forward,
     get_jit_fused_vit_output_forward,
     get_vit_flash_self_attention_forward,
 )

@@ -24,10 +25,17 @@ class ViTPolicy(Policy):
         pass
 
     def preprocess(self):
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
+        from transformers.models.vit.modeling_vit import (
+            ViTEmbeddings,
+            ViTIntermediate,
+            ViTLayer,
+            ViTOutput,
+            ViTSelfAttention,
+        )
 
         policy = {}
 

@@ -83,6 +91,9 @@ class ViTPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",

@@ -94,6 +105,14 @@ class ViTPolicy(Policy):
                     ),
                 ],
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_vit_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=ViTIntermediate,
+                )
 
         # use flash attention
         if self.shard_config.enable_flash_attention:

@@ -115,6 +134,7 @@ class ViTPolicy(Policy):
                 policy=policy,
                 target_key=ViTOutput,
             )
+
        return policy
 
     def new_model_class(self):
