mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
462 lines
19 KiB
462 lines
19 KiB
import colossalai.shardformer.layer as col_nn
|
|
|
|
from ..modeling.blip2 import (
|
|
forward_fn,
|
|
get_blip2_flash_attention_forward,
|
|
get_jit_fused_blip2_mlp_forward,
|
|
get_jit_fused_blip2_QFormer_output_forward,
|
|
get_jit_fused_blip2_QFormer_self_output_forward,
|
|
)
|
|
from ..modeling.jit import get_jit_fused_dropout_add_func
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
__all__ = ["BlipPolicy", "BlipModelPolicy"]
|
|
|
|
|
|
class BlipPolicy(Policy):
|
|
def config_sanity_check(self):
|
|
pass
|
|
|
|
def preprocess(self):
|
|
self.tie_weight = self.tie_weight_check()
|
|
self.enable_bias_gelu_fused = (
|
|
self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
|
|
)
|
|
return self.model
|
|
|
|
def module_policy(self):
|
|
from transformers.models.blip_2.modeling_blip_2 import (
|
|
Blip2Attention,
|
|
Blip2EncoderLayer,
|
|
Blip2MLP,
|
|
Blip2QFormerLayer,
|
|
Blip2QFormerModel,
|
|
Blip2QFormerOutput,
|
|
Blip2QFormerSelfOutput,
|
|
Blip2VisionModel,
|
|
)
|
|
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
|
|
|
|
policy = {}
|
|
|
|
embedding_cls = None
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
embedding_cls = col_nn.VocabParallelEmbedding1D
|
|
else:
|
|
if self.tie_weight:
|
|
embedding_cls = col_nn.PaddingEmbedding
|
|
|
|
if self.shard_config.enable_fused_normalization:
|
|
norm_cls = col_nn.FusedLayerNorm
|
|
else:
|
|
norm_cls = col_nn.LayerNorm
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
assert (
|
|
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
|
policy[Blip2EncoderLayer] = ModulePolicyDescription(
|
|
attribute_replacement={
|
|
"self_attn.num_heads": self.model.config.vision_config.num_attention_heads
|
|
// self.shard_config.tensor_parallel_size,
|
|
"self_attn.embed_dim": self.model.config.vision_config.hidden_size
|
|
// self.shard_config.tensor_parallel_size,
|
|
},
|
|
sub_module_replacement=[
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.qkv",
|
|
target_module=col_nn.FusedLinear1D_Col,
|
|
kwargs={
|
|
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.projection",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="mlp.fc1",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="mlp.fc2",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
],
|
|
)
|
|
|
|
policy[Blip2QFormerModel] = ModulePolicyDescription(
|
|
sub_module_replacement=[
|
|
SubModuleReplacementDescription(
|
|
suffix="dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
]
|
|
)
|
|
|
|
policy[Blip2QFormerLayer] = ModulePolicyDescription(
|
|
attribute_replacement={
|
|
"attention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads
|
|
// self.shard_config.tensor_parallel_size,
|
|
"attention.attention.all_head_size": self.model.config.qformer_config.hidden_size
|
|
// self.shard_config.tensor_parallel_size,
|
|
"crossattention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads
|
|
// self.shard_config.tensor_parallel_size,
|
|
"crossattention.attention.all_head_size": self.model.config.qformer_config.hidden_size
|
|
// self.shard_config.tensor_parallel_size,
|
|
},
|
|
sub_module_replacement=[
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.attention.query",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.attention.key",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.attention.value",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.attention.dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.output.dense",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.output.dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.attention.query",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.attention.key",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.attention.value",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.attention.dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.output.dense",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.output.dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="intermediate_query.dense",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="output_query.dense",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="output_query.dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
],
|
|
)
|
|
|
|
policy[OPTDecoderLayer] = ModulePolicyDescription(
|
|
attribute_replacement={
|
|
"self_attn.embed_dim": self.model.config.text_config.hidden_size
|
|
// self.shard_config.tensor_parallel_size,
|
|
"self_attn.num_heads": self.model.config.text_config.num_attention_heads
|
|
// self.shard_config.tensor_parallel_size,
|
|
},
|
|
sub_module_replacement=[
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.q_proj",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.k_proj",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.v_proj",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn.out_proj",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="fc1",
|
|
target_module=col_nn.Linear1D_Col,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="fc2",
|
|
target_module=col_nn.Linear1D_Row,
|
|
kwargs={
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
],
|
|
)
|
|
|
|
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
|
|
if self.enable_bias_gelu_fused:
|
|
self.append_or_create_method_replacement(
|
|
description={
|
|
"forward": get_jit_fused_blip2_mlp_forward(),
|
|
},
|
|
policy=policy,
|
|
target_key=Blip2MLP,
|
|
)
|
|
|
|
if embedding_cls is not None:
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="model.decoder.embed_tokens",
|
|
target_module=embedding_cls,
|
|
kwargs=(
|
|
{
|
|
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
}
|
|
if self.shard_config.enable_tensor_parallelism
|
|
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
|
),
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=OPTForCausalLM,
|
|
)
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="lm_head",
|
|
target_module=col_nn.VocabParallelLMHead1D,
|
|
kwargs={
|
|
"gather_output": True,
|
|
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
},
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=OPTForCausalLM,
|
|
)
|
|
else:
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="lm_head",
|
|
target_module=col_nn.PaddingLMHead,
|
|
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=OPTForCausalLM,
|
|
)
|
|
# optimization configuration
|
|
# Handle Blip2EncoderLayer layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="layer_norm1",
|
|
target_module=norm_cls,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="layer_norm2",
|
|
target_module=norm_cls,
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=Blip2EncoderLayer,
|
|
)
|
|
|
|
# handle Blip2VisionModel layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="post_layernorm",
|
|
target_module=norm_cls,
|
|
)
|
|
],
|
|
policy=policy,
|
|
target_key=Blip2VisionModel,
|
|
)
|
|
|
|
# handle Blip2VisionModel layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="layernorm",
|
|
target_module=norm_cls,
|
|
)
|
|
],
|
|
policy=policy,
|
|
target_key=Blip2QFormerModel,
|
|
)
|
|
|
|
# handle Blip2QFormerLayer layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="attention.output.LayerNorm",
|
|
target_module=norm_cls,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="crossattention.output.LayerNorm",
|
|
target_module=norm_cls,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="output_query.LayerNorm",
|
|
target_module=norm_cls,
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=Blip2QFormerLayer,
|
|
)
|
|
|
|
# handle OPTForCausalLM layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="model.decoder.final_layer_norm",
|
|
target_module=norm_cls,
|
|
)
|
|
],
|
|
policy=policy,
|
|
target_key=OPTForCausalLM,
|
|
)
|
|
|
|
# handle OPTDecoderLayer layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attn_layer_norm",
|
|
target_module=norm_cls,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="final_layer_norm",
|
|
target_module=norm_cls,
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=OPTDecoderLayer,
|
|
)
|
|
|
|
# use flash attention
|
|
if self.shard_config.enable_flash_attention:
|
|
self.append_or_create_method_replacement(
|
|
description={
|
|
"forward": get_blip2_flash_attention_forward(),
|
|
},
|
|
policy=policy,
|
|
target_key=Blip2Attention,
|
|
)
|
|
|
|
# use jit operator
|
|
if self.shard_config.enable_jit_fused:
|
|
self.append_or_create_method_replacement(
|
|
description={
|
|
"forward": get_jit_fused_blip2_QFormer_self_output_forward(),
|
|
"dropout_add": get_jit_fused_dropout_add_func(),
|
|
},
|
|
policy=policy,
|
|
target_key=Blip2QFormerSelfOutput,
|
|
)
|
|
self.append_or_create_method_replacement(
|
|
description={
|
|
"forward": get_jit_fused_blip2_QFormer_output_forward(),
|
|
"dropout_add": get_jit_fused_dropout_add_func(),
|
|
},
|
|
policy=policy,
|
|
target_key=Blip2QFormerOutput,
|
|
)
|
|
|
|
return policy
|
|
|
|
def postprocess(self):
|
|
return self.model
|
|
|
|
|
|
# Blip2Model
|
|
class Blip2ModelPolicy(BlipPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
|
|
# Blip2ForConditionalGeneration
|
|
class Blip2ForConditionalGenerationPolicy(BlipPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|