ColossalAI/colossalai/shardformer/policies/sam.py

223 lines
8.7 KiB
Python
Raw Normal View History

import colossalai.shardformer.layer as col_nn
[shardformer] update transformers (#5583) * flash_attention forward upgrade * llama_model_forward * remove useless comment * update the requirements.txt * add the transformers version requirements * remove the LATEST VERSION try * [shardformer] update bloom model (#5518) * update bloom model * remove the version restriction * [shardformer] update_falcon (#5520) * [shardformer] update mistral model (#5511) * [shardformer] update gpt2 (#5502) * [shardformer] update gptj model (#5503) * [shardformer] update opt (#5522) * [shardformer] update t5 model (#5524) * [shardformer] update whisper model (#5529) * [shardformer] update vit model (#5530) * update vit model * remove the output_hidden_states * [shardformer] fix llama modeling * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements * [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements * fix conflicts * [doc] fix ColossalMoE readme (#5599) * fix readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * merge with main * merge with main * llama_model_forward * remove useless comment * remove the LATEST VERSION try * [shardformer] update bloom model (#5518) * update bloom model * remove the version restriction * [shardformer] update mistral model (#5511) * [shardformer] update opt (#5522) * [shardformer] update whisper model (#5529) * [shardformer] fix llama modeling * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [hotfix] Fix examples no 
pad token & auto parallel codegen bug; (#5606) * fix no pad token bug * fixed some auto parallel codegen bug, but might not run on torch 2.1 --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] fix pipeline grad ckpt (#5620) * [shardformer] fix pipeline grad ckpt * [shardformer] fix whisper (#5628) * [test] fix llama model test * fix the opt upgrade (#5634) * [shardformer] fix attn replacement (#5636) * [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [test] fix llama test (#5638) * [gemini] fix buffer cast (#5639) * Fix shardformer upgrade (#5640) * fix llama model * fix the mistral * fix the shardformer model * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]support pipeline parallelism for mistral. 
(#5642) * [shardformer] fix attn replacement (#5636) * [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Support LLaMA-3 CPT and ST (#5619) * support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [exampe] update llama example (#5626) * [plugin] support dp inside for hybriad parallel * [example] update llama benchmark * [example] update llama benchmark * [example] update llama readme * [example] update llama readme * [example] llama3 (#5631) * release llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [test] fix llama test (#5638) * [gemini] fix buffer cast (#5639) * support pp for mistral * fix * fix fix fix * fix --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com>
2024-04-24 14:51:50 +00:00
from ..modeling.sam import forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["SamPolicy", "SamModelPolicy"]
class SamPolicy(Policy):
    """Shardformer policy for the SAM modules of ``transformers.models.sam``.

    The policy maps SAM module classes to the attribute, method and sub-module
    replacements that should be applied when the model is sharded.
    """

    def config_sanity_check(self):
        """Perform no SAM-specific configuration checks."""

    def preprocess(self):
        """Return ``self.model`` unchanged; no preprocessing is applied."""
        return self.model

    def module_policy(self):
        """Build the replacement policy for the SAM module classes.

        When ``shard_config.enable_tensor_parallelism`` is set, the projections of
        ``SamVisionLayer``, ``SamTwoWayAttentionBlock`` and ``SamTwoWayTransformer``
        are replaced with ``col_nn.Linear1D_Col`` / ``col_nn.Linear1D_Row``
        (``col_nn.FusedLinear1D_Col`` for the fused vision ``qkv``), the matching
        ``num_attention_heads`` attributes are divided by the tensor parallel
        size, and ``SamVisionAttention`` receives a ``dropout_layer`` attribute
        together with a replacement ``forward``.

        Layer norms are always registered for replacement, using
        ``col_nn.FusedLayerNorm`` when ``shard_config.enable_fused_normalization``
        is set and ``col_nn.LayerNorm`` otherwise.

        Returns:
            dict: the policy, keyed by SAM module class.

        Raises:
            AssertionError: if tensor parallelism is enabled and the vision or
                mask-decoder head count is not divisible by the tensor parallel
                size.
        """
        from transformers.models.sam.modeling_sam import (
            SamTwoWayAttentionBlock,
            SamTwoWayTransformer,
            SamVisionAttention,
            SamVisionLayer,
        )

        def _attention_projections(prefix):
            # q/k/v use Linear1D_Col and the output projection uses Linear1D_Row.
            projections = [
                SubModuleReplacementDescription(
                    suffix=f"{prefix}.{name}",
                    target_module=col_nn.Linear1D_Col,
                )
                for name in ("q_proj", "k_proj", "v_proj")
            ]
            projections.append(
                SubModuleReplacementDescription(
                    suffix=f"{prefix}.out_proj",
                    target_module=col_nn.Linear1D_Row,
                )
            )
            return projections

        def _mlp_projections():
            # First MLP linear uses Linear1D_Col, second uses Linear1D_Row.
            return [
                SubModuleReplacementDescription(
                    suffix="mlp.lin1",
                    target_module=col_nn.Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="mlp.lin2",
                    target_module=col_nn.Linear1D_Row,
                ),
            ]

        def _norm_replacements(*suffixes):
            # One replacement per layer-norm attribute, all using `norm_cls`.
            return [
                SubModuleReplacementDescription(
                    suffix=suffix,
                    target_module=norm_cls,
                )
                for suffix in suffixes
            ]

        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm

        if self.shard_config.enable_tensor_parallelism:
            tp_size = self.shard_config.tensor_parallel_size

            vision_heads = self.model.config.vision_config.num_attention_heads
            assert (
                vision_heads % tp_size == 0
            ), "The number of attention heads must be divisible by tensor parallel size."

            # The mask decoder heads are floor-divided below as well, so they need
            # the same guard; otherwise the head count is silently truncated.
            decoder_heads = self.model.config.mask_decoder_config.num_attention_heads
            assert (
                decoder_heads % tp_size == 0
            ), "The number of mask decoder attention heads must be divisible by tensor parallel size."

            vision_heads_per_rank = vision_heads // tp_size
            decoder_heads_per_rank = decoder_heads // tp_size

            policy[SamVisionLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attn.num_attention_heads": vision_heads_per_rank,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attn.qkv",
                        target_module=col_nn.FusedLinear1D_Col,
                        kwargs={
                            "n_fused": 3,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attn.proj",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    *_mlp_projections(),
                ],
            )

            policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.num_attention_heads": decoder_heads_per_rank,
                },
                sub_module_replacement=[
                    *_attention_projections("self_attn"),
                    *_attention_projections("cross_attn_token_to_image"),
                    *_mlp_projections(),
                    *_attention_projections("cross_attn_image_to_token"),
                ],
            )

            policy[SamTwoWayTransformer] = ModulePolicyDescription(
                attribute_replacement={
                    "final_attn_token_to_image.num_attention_heads": decoder_heads_per_rank,
                },
                sub_module_replacement=_attention_projections("final_attn_token_to_image"),
            )

            # add `DropoutForParallelInput` layer to replace the usage of `nn.functional.dropout`
            policy[SamVisionAttention] = ModulePolicyDescription(
                attribute_replacement={
                    "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
                },
                method_replacement={"forward": forward_fn()},
                sub_module_replacement=[],
            )

        # optimization configuration
        # Handle SamVisionLayer
        self.append_or_create_submodule_replacement(
            description=_norm_replacements("layer_norm1", "layer_norm2"),
            policy=policy,
            target_key=SamVisionLayer,
        )

        # Handle SamTwoWayAttentionBlock
        self.append_or_create_submodule_replacement(
            description=_norm_replacements("layer_norm1", "layer_norm2", "layer_norm3", "layer_norm4"),
            policy=policy,
            target_key=SamTwoWayAttentionBlock,
        )

        # Handle SamTwoWayTransformer
        self.append_or_create_submodule_replacement(
            description=_norm_replacements("layer_norm_final_attn"),
            policy=policy,
            target_key=SamTwoWayTransformer,
        )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged; no postprocessing is applied."""
        return self.model
# SamModel
class SamModelPolicy(SamPolicy):
    """Policy for the SAM model; inherits every rule from ``SamPolicy``."""

    def __init__(self) -> None:
        # Nothing model-specific to set up: defer entirely to the parent initializer.
        super().__init__()