mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
220 lines
11 KiB
220 lines
11 KiB
import torch.nn as nn
|
|
|
|
import colossalai.shardformer.layer as col_nn
|
|
|
|
from .._utils import getattr_, setattr_
|
|
from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
# Public policy classes exported by this module.
__all__ = ['SamPolicy', 'SamModelPolicy']
|
|
|
|
|
|
class SamPolicy(Policy):
    """Shardformer policy for Hugging Face ``transformers`` SAM (Segment Anything) models.

    Depending on ``self.shard_config`` the policy:
      * shards the vision-encoder layers and the mask-decoder two-way transformer for
        tensor parallelism,
      * swaps the listed layer norms for ``col_nn.FusedLayerNorm``,
      * swaps attention ``forward`` methods for flash-attention variants.
    """

    def config_sanity_check(self):
        # No SAM-specific shard-config constraints are checked here.
        pass

    def preprocess(self):
        # The model is used as-is; nothing to prepare before module replacement.
        return self.model

    def module_policy(self):
        """Build the ``{module class: ModulePolicyDescription}`` mapping for this model.

        Returns:
            dict: policy descriptions keyed by Hugging Face SAM module classes.

        Raises:
            ValueError: if tensor parallelism is enabled and the vision or mask-decoder
                attention head count is not divisible by ``tensor_parallel_size``.
        """
        from transformers.models.sam.modeling_sam import (
            SamAttention,
            SamTwoWayAttentionBlock,
            SamTwoWayTransformer,
            SamVisionAttention,
            SamVisionLayer,
        )

        policy = {}

        def _add_method_replacement(target_key, methods):
            # Merge into an existing description rather than overwriting it, so replacements
            # registered by an earlier branch (e.g. tensor parallelism) are not lost.
            description = policy.get(target_key)
            if description is None:
                policy[target_key] = ModulePolicyDescription(method_replacement=dict(methods))
            elif description.method_replacement is None:
                description.method_replacement = dict(methods)
            else:
                description.method_replacement.update(methods)

        if self.shard_config.enable_tensor_parallelism:
            tp_size = self.shard_config.tensor_parallel_size
            vision_heads = self.model.config.vision_config.num_attention_heads
            decoder_heads = self.model.config.mask_decoder_config.num_attention_heads
            # Integer division below would silently truncate otherwise.
            for config_name, num_heads in (("vision_config", vision_heads), ("mask_decoder_config", decoder_heads)):
                if num_heads % tp_size != 0:
                    raise ValueError(f"The number of attention heads in `{config_name}` ({num_heads}) must be "
                                     f"divisible by tensor_parallel_size ({tp_size}).")
            vision_heads_per_rank = vision_heads // tp_size
            decoder_heads_per_rank = decoder_heads // tp_size

            def _linear(suffix, target_module):
                return SubModuleReplacementDescription(suffix=suffix, target_module=target_module)

            def _attention_linears(prefix):
                # q/k/v projections are split column-wise, the output projection row-wise.
                return [
                    _linear(f"{prefix}.q_proj", col_nn.Linear1D_Col),
                    _linear(f"{prefix}.k_proj", col_nn.Linear1D_Col),
                    _linear(f"{prefix}.v_proj", col_nn.Linear1D_Col),
                    _linear(f"{prefix}.out_proj", col_nn.Linear1D_Row),
                ]

            mlp_linears = [
                _linear("mlp.lin1", col_nn.Linear1D_Col),
                _linear("mlp.lin2", col_nn.Linear1D_Row),
            ]

            policy[SamVisionLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attn.num_attention_heads": vision_heads_per_rank,
                },
                sub_module_replacement=[
                    # q, k and v are produced by a single fused projection.
                    SubModuleReplacementDescription(
                        suffix="attn.qkv",
                        target_module=col_nn.FusedLinear1D_Col,
                        kwargs={
                            "n_fused": 3,
                        },
                    ),
                    _linear("attn.proj", col_nn.Linear1D_Row),
                    *mlp_linears,
                ],
            )

            # Every attention whose projections are sharded must see only its local share of
            # heads; otherwise the local channels get split into the full head count.
            policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.num_attention_heads": decoder_heads_per_rank,
                    "cross_attn_token_to_image.num_attention_heads": decoder_heads_per_rank,
                    "cross_attn_image_to_token.num_attention_heads": decoder_heads_per_rank,
                },
                sub_module_replacement=[
                    *_attention_linears("self_attn"),
                    *_attention_linears("cross_attn_token_to_image"),
                    *mlp_linears,
                    *_attention_linears("cross_attn_image_to_token"),
                ],
            )

            policy[SamTwoWayTransformer] = ModulePolicyDescription(
                attribute_replacement={
                    "final_attn_token_to_image.num_attention_heads": decoder_heads_per_rank,
                },
                sub_module_replacement=_attention_linears("final_attn_token_to_image"),
            )

            # add `DropoutForParallelInput` layer to replace the usage of `nn.functional.dropout`
            policy[SamVisionAttention] = ModulePolicyDescription(
                attribute_replacement={
                    "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
                },
                method_replacement={"forward": forward_fn()},
                sub_module_replacement=[],
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            fused_norm_suffixes = {
                SamVisionLayer: ("layer_norm1", "layer_norm2"),
                SamTwoWayAttentionBlock: ("layer_norm1", "layer_norm2", "layer_norm3", "layer_norm4"),
                SamTwoWayTransformer: ("layer_norm_final_attn",),
            }
            for target_key, suffixes in fused_norm_suffixes.items():
                self.append_or_create_submodule_replacement(
                    description=[
                        SubModuleReplacementDescription(suffix=suffix, target_module=col_nn.FusedLayerNorm)
                        for suffix in suffixes
                    ],
                    policy=policy,
                    target_key=target_key,
                )

        # use flash attention
        if self.shard_config.enable_flash_attention:
            _add_method_replacement(SamAttention, {'forward': get_sam_flash_attention_forward()})
            # Under tensor parallelism the vision attention already gets `forward_fn()` together
            # with a `DropoutForParallelInput` dropout layer (see above). Assigning a new
            # description here used to discard both, so keep the TP forward in that case.
            if not self.shard_config.enable_tensor_parallelism:
                _add_method_replacement(SamVisionAttention, {'forward': get_sam_vision_flash_attention_forward()})

        return policy

    def postprocess(self):
        # No post-replacement fix-ups are needed.
        return self.model
|
|
|
|
|
|
# SamModel
|
|
class SamModelPolicy(SamPolicy):
    """Policy for ``SamModel``; reuses every replacement rule defined by ``SamPolicy``."""

    def __init__(self) -> None:
        # Nothing model-specific to set up; defer entirely to the parent initializer.
        super(SamModelPolicy, self).__init__()
|