2023-07-14 07:56:59 +00:00
|
|
|
import colossalai.shardformer.layer as col_nn
|
|
|
|
|
2024-04-24 14:51:50 +00:00
|
|
|
from ..modeling.sam import forward_fn
|
2023-08-01 10:02:49 +00:00
|
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
2023-07-14 07:56:59 +00:00
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
__all__ = ["SamPolicy", "SamModelPolicy"]
|
2023-07-14 07:56:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SamPolicy(Policy):
|
|
|
|
    def config_sanity_check(self):
        """Validate the shard config for the SAM policy.

        No configuration constraints are enforced here; divisibility of
        attention heads by the tensor-parallel size is asserted later in
        ``module_policy`` instead.
        """
        pass
|
|
|
|
|
|
|
|
def preprocess(self):
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
def module_policy(self):
|
|
|
|
from transformers.models.sam.modeling_sam import (
|
|
|
|
SamTwoWayAttentionBlock,
|
|
|
|
SamTwoWayTransformer,
|
|
|
|
SamVisionAttention,
|
|
|
|
SamVisionLayer,
|
|
|
|
)
|
|
|
|
|
|
|
|
policy = {}
|
|
|
|
|
2023-11-03 05:32:43 +00:00
|
|
|
if self.shard_config.enable_fused_normalization:
|
|
|
|
norm_cls = col_nn.FusedLayerNorm
|
|
|
|
else:
|
|
|
|
norm_cls = col_nn.LayerNorm
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
2023-11-03 05:32:43 +00:00
|
|
|
|
2023-07-14 07:56:59 +00:00
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
2024-04-29 10:47:47 +00:00
|
|
|
assert (
|
|
|
|
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
|
|
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
2023-09-19 06:20:26 +00:00
|
|
|
policy[SamVisionLayer] = ModulePolicyDescription(
|
|
|
|
attribute_replacement={
|
|
|
|
"attn.num_attention_heads": self.model.config.vision_config.num_attention_heads
|
|
|
|
// self.shard_config.tensor_parallel_size,
|
|
|
|
},
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attn.qkv",
|
|
|
|
target_module=col_nn.FusedLinear1D_Col,
|
|
|
|
kwargs={
|
2024-10-10 06:34:45 +00:00
|
|
|
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
2024-08-12 10:17:05 +00:00
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2023-09-19 06:20:26 +00:00
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attn.proj",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin1",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin2",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
],
|
|
|
|
)
|
2023-07-14 07:56:59 +00:00
|
|
|
policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
|
|
|
|
attribute_replacement={
|
2023-09-19 06:20:26 +00:00
|
|
|
"self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads
|
|
|
|
// self.shard_config.tensor_parallel_size,
|
2023-07-14 07:56:59 +00:00
|
|
|
},
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.q_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.k_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.v_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.out_proj",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.q_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.k_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.v_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.out_proj",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin1",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin2",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.q_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.k_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.v_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.out_proj",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-07-14 07:56:59 +00:00
|
|
|
),
|
2023-09-19 06:20:26 +00:00
|
|
|
],
|
|
|
|
)
|
|
|
|
policy[SamTwoWayTransformer] = ModulePolicyDescription(
|
|
|
|
attribute_replacement={
|
|
|
|
"final_attn_token_to_image.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads
|
|
|
|
// self.shard_config.tensor_parallel_size,
|
|
|
|
},
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.q_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.k_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.v_proj",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.out_proj",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
2024-08-12 10:17:05 +00:00
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.
* [Shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policy
* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;
* [fix] fix incorrect number of gradients ;
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [Shardformer] add en doc for zbv;
* [fix] fix typo in Model compatibility table
* [fix] fix API Reference typo
* [Shardformer] add zh-Han doc for zbv
* [fix] fix Linear name; update en & zh doc
* [fix] fix shardformer doc import err
* [fix] fix shardconfig import in doc
* [fix] fix shardformer doc
* [fix] fix shardconfig doc
* [fix] fix config
* [fix] remove shardconfig
* [fix] fix doc
* [feat] add zbv doc string
* [fix] rm doc
* [fix] fix doc
* [fix] empty zbv doc
* [fix] ifx torch version
* [fix] fix torch version
* [fix] fix torch versions
* [fix] fix torch versions
* [fix] fix pyramid versions
* [fix] fix pyramid, zope version
* [fix] try fix workflow
* [fix] try import ShardConfig in yml
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix workflow
* [fix] fix ci
* [fix] fix zbv doc
* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;
* [fix] fix policy use fused_linear
* [fix] fix weight grad none, err caused by weight ptr change
* [fix] fix comm in WeightGradStore
* [fix] fix WeightGradStore pop param
* [fix] remove useless param in doc; fix gpt2 qkv test;
* [shardformer] simplify execute_w_pass_grad_accum;
* [fix] rm useless comments
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] Run meaningful doc test
* [shadformer] fix doc test cmd;
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 02:22:26 +00:00
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
# add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout`
|
|
|
|
policy[SamVisionAttention] = ModulePolicyDescription(
|
|
|
|
attribute_replacement={
|
|
|
|
"dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
|
|
|
|
},
|
|
|
|
method_replacement={"forward": forward_fn()},
|
|
|
|
sub_module_replacement=[],
|
|
|
|
)
|
|
|
|
elif use_zbv:
|
|
|
|
policy[SamVisionLayer] = ModulePolicyDescription(
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attn.qkv",
|
|
|
|
target_module=col_nn.FusedLinear,
|
|
|
|
kwargs={
|
|
|
|
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attn.proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin1",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin2",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
],
|
|
|
|
)
|
|
|
|
policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
|
|
|
|
attribute_replacement={
|
|
|
|
"self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads
|
|
|
|
// self.shard_config.tensor_parallel_size,
|
|
|
|
},
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.q_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.k_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.v_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="self_attn.out_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.q_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.k_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.v_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_token_to_image.out_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin1",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="mlp.lin2",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.q_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.k_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.v_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="cross_attn_image_to_token.out_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
],
|
|
|
|
)
|
|
|
|
policy[SamTwoWayTransformer] = ModulePolicyDescription(
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.q_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.k_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.v_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="final_attn_token_to_image.out_proj",
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
kwargs={
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"use_zbv": use_zbv,
|
2024-08-12 10:17:05 +00:00
|
|
|
},
|
2023-09-19 06:20:26 +00:00
|
|
|
),
|
|
|
|
],
|
|
|
|
)
|
2023-07-14 07:56:59 +00:00
|
|
|
|
|
|
|
# add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout`
|
2023-09-19 06:20:26 +00:00
|
|
|
policy[SamVisionAttention] = ModulePolicyDescription(
|
|
|
|
attribute_replacement={
|
|
|
|
"dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
|
|
|
|
},
|
|
|
|
method_replacement={"forward": forward_fn()},
|
|
|
|
sub_module_replacement=[],
|
|
|
|
)
|
2023-07-14 07:56:59 +00:00
|
|
|
|
|
|
|
# optimization configuration
|
2023-11-03 05:32:43 +00:00
|
|
|
# Handle SamVisionLayer
|
|
|
|
self.append_or_create_submodule_replacement(
|
|
|
|
description=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm1",
|
|
|
|
target_module=norm_cls,
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm2",
|
|
|
|
target_module=norm_cls,
|
|
|
|
),
|
|
|
|
],
|
|
|
|
policy=policy,
|
|
|
|
target_key=SamVisionLayer,
|
|
|
|
)
|
2023-07-14 07:56:59 +00:00
|
|
|
|
2023-11-03 05:32:43 +00:00
|
|
|
# Handle SamTwoWayAttentionBlock
|
|
|
|
self.append_or_create_submodule_replacement(
|
|
|
|
description=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm1",
|
|
|
|
target_module=norm_cls,
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm2",
|
|
|
|
target_module=norm_cls,
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm3",
|
|
|
|
target_module=norm_cls,
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm4",
|
|
|
|
target_module=norm_cls,
|
|
|
|
),
|
|
|
|
],
|
|
|
|
policy=policy,
|
|
|
|
target_key=SamTwoWayAttentionBlock,
|
|
|
|
)
|
2023-07-14 07:56:59 +00:00
|
|
|
|
2023-11-03 05:32:43 +00:00
|
|
|
# Handle SamTwoWayTransformer
|
|
|
|
self.append_or_create_submodule_replacement(
|
|
|
|
description=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="layer_norm_final_attn",
|
|
|
|
target_module=norm_cls,
|
|
|
|
)
|
|
|
|
],
|
|
|
|
policy=policy,
|
|
|
|
target_key=SamTwoWayTransformer,
|
|
|
|
)
|
2023-07-14 07:56:59 +00:00
|
|
|
|
|
|
|
return policy
|
|
|
|
|
|
|
|
def postprocess(self):
    """No-op post-sharding hook.

    SAM requires no parameter tying or other fixups after the policy has
    been applied, so the (possibly sharded) model is returned unchanged.
    """
    return self.model
|
|
|
|
|
|
|
|
|
|
|
|
# SamModel
|
|
|
|
class SamModelPolicy(SamPolicy):
    """Shardformer policy for the full `SamModel` architecture.

    Inherits every sharding rule (vision layers, two-way attention blocks,
    norm replacement, etc.) from :class:`SamPolicy` without modification.
    """

    def __init__(self) -> None:
        # Pass-through constructor; kept explicit to mirror the other
        # model-level policy classes in this package.
        super().__init__()
|