2023-08-24 07:50:02 +00:00
|
|
|
import warnings
|
2023-07-25 07:02:29 +00:00
|
|
|
from typing import Callable, Dict, List, Union
|
2023-06-28 05:28:18 +00:00
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-07-25 07:02:29 +00:00
|
|
|
import colossalai.shardformer.layer as col_nn
|
2023-08-07 08:41:07 +00:00
|
|
|
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col
|
2023-06-28 05:28:18 +00:00
|
|
|
|
2023-08-07 08:41:07 +00:00
|
|
|
from ..modeling.jit import get_jit_fused_dropout_add_func
|
2023-07-25 07:02:29 +00:00
|
|
|
from ..modeling.vit import (
|
|
|
|
ViTForImageClassification_pipeline_forward,
|
|
|
|
ViTForMaskedImageModeling_pipeline_forward,
|
|
|
|
ViTModel_pipeline_forward,
|
2024-04-29 07:33:51 +00:00
|
|
|
get_jit_fused_vit_intermediate_forward,
|
2023-08-07 08:41:07 +00:00
|
|
|
get_jit_fused_vit_output_forward,
|
|
|
|
get_vit_flash_self_attention_forward,
|
2023-07-25 07:02:29 +00:00
|
|
|
)
|
2023-07-05 07:13:00 +00:00
|
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
2023-06-28 05:28:18 +00:00
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
# Names exported by `from <this module> import *` — the ViT shardformer policy classes.
__all__ = ["ViTPolicy", "ViTModelPolicy", "ViTForImageClassificationPolicy", "ViTForMaskedImageModelingPolicy"]
|
2023-06-30 02:56:29 +00:00
|
|
|
|
2023-06-28 07:04:35 +00:00
|
|
|
|
2023-06-28 05:28:18 +00:00
|
|
|
class ViTPolicy(Policy):
|
2023-06-30 01:32:37 +00:00
|
|
|
def config_sanity_check(self):
    """Validate the model configuration before sharding.

    The ViT policy imposes no additional constraints here, so this is a no-op.
    """
|
|
|
|
|
2023-06-28 05:28:18 +00:00
|
|
|
def preprocess(self):
    """Decide whether the fused bias+GELU path is usable and return the model unchanged.

    Sets ``self.enable_bias_gelu_fused`` to
    ``shard_config.enable_jit_fused and model.config.hidden_act == "gelu"``.
    The ``and`` short-circuits, so ``hidden_act`` is only read when JIT fusion is enabled.
    """
    jit_enabled = self.shard_config.enable_jit_fused
    self.enable_bias_gelu_fused = jit_enabled and self.model.config.hidden_act == "gelu"
    return self.model
|
2023-06-28 07:04:35 +00:00
|
|
|
|
2023-06-28 05:28:18 +00:00
|
|
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
    """Build the module-replacement policy for a HuggingFace ViT model.

    Behaviour, driven by ``self.shard_config`` and ``self.pipeline_stage_manager``:

    * Sequence parallelism is not supported: if requested, the flag is switched
      off on ``shard_config`` and a warning is emitted.
    * With tensor parallelism, the attention head count / all-head size are
      divided by ``tensor_parallel_size``; Q/K/V and ``intermediate.dense`` are
      replaced by ``Linear1D_Col`` and ``attention.output.dense`` /
      ``output.dense`` by ``Linear1D_Row``.
    * Without tensor parallelism but with the zbv pipeline schedule, the same
      projections are replaced by ``LinearWithGradAccum`` and no attributes are
      rewritten.
    * In either of the two cases above, ``ViTIntermediate.forward`` is swapped
      for the JIT-fused bias+GELU forward when ``self.enable_bias_gelu_fused``.
    * Flash attention / JIT-fused output forwards are added when enabled.

    Returns:
        Dict mapping HuggingFace module classes to their ``ModulePolicyDescription``.
    """
    from transformers.models.vit.modeling_vit import (
        ViTEmbeddings,
        ViTIntermediate,
        ViTLayer,
        ViTOutput,
        ViTSelfAttention,
    )

    policy = {}

    if self.shard_config.enable_sequence_parallelism:
        self.shard_config.enable_sequence_parallelism = False
        warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

    use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
    use_tp = self.shard_config.enable_tensor_parallelism

    # The TP and zbv paths replace exactly the same submodules; they differ only in
    # the linear classes used and in TP additionally rewriting the head attributes.
    if use_tp or use_zbv:
        if use_tp:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            ), "The number of attention heads must be divisible by tensor parallel size."
            col_linear, row_linear = col_nn.Linear1D_Col, col_nn.Linear1D_Row
        else:
            col_linear = row_linear = col_nn.LinearWithGradAccum

        def linear_kwargs(**extra):
            # Fresh kwargs dict per replaced projection (no dict is shared between descriptions).
            return {
                **extra,
                "fp8_communication": self.shard_config.fp8_communication,
                "use_zbv": use_zbv,
            }

        policy[ViTEmbeddings] = ModulePolicyDescription(
            attribute_replacement={},
            param_replacement=[],
            sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=DropoutForReplicatedInput,
                )
            ],
        )

        layer_replacements = [
            SubModuleReplacementDescription(
                suffix="attention.attention.query",
                target_module=col_linear,
                kwargs=linear_kwargs(),
            ),
            SubModuleReplacementDescription(
                suffix="attention.attention.key",
                target_module=col_linear,
                kwargs=linear_kwargs(),
            ),
            SubModuleReplacementDescription(
                suffix="attention.attention.value",
                target_module=col_linear,
                kwargs=linear_kwargs(),
            ),
            SubModuleReplacementDescription(
                suffix="attention.attention.dropout",
                target_module=col_nn.DropoutForParallelInput,
            ),
            SubModuleReplacementDescription(
                suffix="attention.output.dense",
                target_module=row_linear,
                kwargs=linear_kwargs(),
            ),
            SubModuleReplacementDescription(
                suffix="attention.output.dropout",
                target_module=col_nn.DropoutForReplicatedInput,
            ),
            SubModuleReplacementDescription(
                suffix="intermediate.dense",
                target_module=col_linear,
                # When the fused bias+GELU forward is installed below, the bias is
                # not added inside the linear layer.
                kwargs=linear_kwargs(skip_bias_add=self.enable_bias_gelu_fused),
            ),
            SubModuleReplacementDescription(
                suffix="output.dense",
                target_module=row_linear,
                kwargs=linear_kwargs(),
            ),
            SubModuleReplacementDescription(
                suffix="output.dropout",
                target_module=col_nn.DropoutForReplicatedInput,
            ),
        ]

        if use_tp:
            tp_size = self.shard_config.tensor_parallel_size
            policy[ViTLayer] = ModulePolicyDescription(
                # Each TP rank holds only its share of the attention heads.
                attribute_replacement={
                    "attention.attention.num_attention_heads": self.model.config.num_attention_heads // tp_size,
                    "attention.attention.all_head_size": self.model.config.hidden_size // tp_size,
                },
                param_replacement=[],
                sub_module_replacement=layer_replacements,
            )
        else:
            policy[ViTLayer] = ModulePolicyDescription(
                param_replacement=[],
                sub_module_replacement=layer_replacements,
            )

        if self.enable_bias_gelu_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_vit_intermediate_forward(),
                },
                policy=policy,
                target_key=ViTIntermediate,
            )

    # use flash attention
    if self.shard_config.enable_flash_attention:
        self.append_or_create_method_replacement(
            description={
                "forward": get_vit_flash_self_attention_forward(),
            },
            policy=policy,
            target_key=ViTSelfAttention,
        )

    # use jit fused operator
    if self.shard_config.enable_jit_fused:
        self.append_or_create_method_replacement(
            description={
                "forward": get_jit_fused_vit_output_forward(),
                "dropout_add": get_jit_fused_dropout_add_func(),
            },
            policy=policy,
            target_key=ViTOutput,
        )

    return policy
|
2023-06-30 01:32:37 +00:00
|
|
|
|
2023-06-28 05:28:18 +00:00
|
|
|
def new_model_class(self):
    """Return the replacement model class; the ViT policy supplies none, so this is ``None``."""
    return
|
|
|
|
|
|
|
|
def postprocess(self):
    """Post-processing hook; the ViT policy applies none and returns ``self.model`` unchanged."""
    model = self.model
    return model
|
2023-07-25 07:02:29 +00:00
|
|
|
|
|
|
|
def get_held_layers(self) -> List[nn.Module]:
    """Get pipeline layers for current stage.

    The result is the embeddings (first stage only) followed by the slice(s)
    of ``encoder.layer`` that the stage manager assigns to this stage.
    Non-interleaved schedules yield a single ``(start, end)`` range.
    Interleaved schedules yield a list of ranges, presumably one per model
    chunk; all of them are collected.
    """
    assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

    # A bare ViTModel is its own backbone; any other model exposes it as ``.vit``.
    backbone = self.model if self.model.__class__.__name__ == "ViTModel" else self.model.vit
    manager = self.pipeline_stage_manager
    encoder_blocks = backbone.encoder.layer
    stage_layers = []

    if not manager.is_interleave:
        split = manager.distribute_layers(len(encoder_blocks))
        if manager.is_first_stage():
            stage_layers.append(backbone.embeddings)
        first, last = manager.get_stage_index(split)
        stage_layers.extend(encoder_blocks[first:last])
        return stage_layers

    assert manager.num_model_chunks is not None
    split = manager.distribute_layers(len(encoder_blocks))
    chunk_ranges = manager.get_stage_index(split)
    if manager.is_first_stage(ignore_chunk=True):
        stage_layers.append(backbone.embeddings)
    for first, last in chunk_ranges:
        stage_layers.extend(encoder_blocks[first:last])
    return stage_layers
|
|
|
|
|
|
|
|
def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
    """Register a stage-aware ``forward`` for ``model_cls`` in ``policy``.

    Does nothing when no pipeline stage manager is configured. Otherwise the
    encoder layers are distributed across stages, and
    ``pipeline_forward(stage_manager=..., stage_index=...)`` is recorded as the
    ``forward`` method replacement for ``model_cls``.
    """
    manager = self.pipeline_stage_manager
    if not manager:
        return

    # A bare ViTModel is its own backbone; any other model exposes it as ``.vit``.
    backbone = self.model if self.model.__class__.__name__ == "ViTModel" else self.model.vit
    split = manager.distribute_layers(len(backbone.encoder.layer))
    stage_index = manager.get_stage_index(split)

    staged_forward = pipeline_forward(stage_manager=manager, stage_index=stage_index)
    self.append_or_create_method_replacement(
        description={"forward": staged_forward},
        policy=policy,
        target_key=model_cls,
    )
|
2023-07-25 07:02:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
# ViTModel
|
|
|
|
class ViTModelPolicy(ViTPolicy):
    """Shardformer policy for a bare ``ViTModel`` (no task head)."""

    def module_policy(self):
        """Extend the base ViT policy with the pipeline forward for ``ViTModel``."""
        from transformers.models.vit.modeling_vit import ViTModel

        policy = super().module_policy()

        if self.shard_config.pipeline_stage_manager is not None:
            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
        return policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the base stage layers plus ``layernorm``/``pooler`` on the output stage.

        For interleaved schedules the output modules go to
        ``is_first_stage(ignore_chunk=True)`` when ``use_zbv`` is set, and to
        ``is_last_stage(ignore_chunk=True)`` otherwise. Non-interleaved
        schedules place them on ``is_last_stage()``.
        """
        held_layers = super().get_held_layers()
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        module = self.model
        stage_manager = self.pipeline_stage_manager
        if stage_manager.is_interleave:
            holds_output = (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
            )
        else:
            holds_output = stage_manager.is_last_stage()

        if holds_output:
            held_layers.append(module.layernorm)
            # A ``None`` pooler is not a module and must not enter the held-layer list.
            # NOTE(review): assumes HF ViTModel sets ``pooler = None`` when built with
            # ``add_pooling_layer=False`` — confirm against the installed transformers.
            if module.pooler is not None:
                held_layers.append(module.pooler)

        return held_layers
|
|
|
|
|
|
|
|
|
|
|
|
# ViTForImageClassification
|
|
|
|
class ViTForImageClassificationPolicy(ViTPolicy):
    """Shardformer policy for ``ViTForImageClassification``."""

    def module_policy(self):
        """Extend the base ViT policy with classifier replacement and pipeline forwards.

        The ``classifier`` head is replaced by ``Linear1D_Col`` when tensor
        parallelism is enabled. When only ``use_zbv`` is set, it is replaced by
        ``col_nn.LinearWithGradAccum``. Otherwise it is left untouched. Both
        replacements receive the same kwargs.
        """
        from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel

        policy = super().module_policy()

        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

        # The two replacement paths differ only in the target layer class.
        if self.shard_config.enable_tensor_parallelism:
            classifier_cls = Linear1D_Col
        elif use_zbv:
            classifier_cls = col_nn.LinearWithGradAccum
        else:
            classifier_cls = None

        if classifier_cls is not None:
            new_item = {
                ViTForImageClassification: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="classifier",
                            target_module=classifier_cls,
                            kwargs=dict(
                                gather_output=True,
                                fp8_communication=self.shard_config.fp8_communication,
                                use_zbv=use_zbv,
                            ),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.shard_config.pipeline_stage_manager is not None:
            self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
            self.set_pipeline_forward(
                model_cls=ViTForImageClassification,
                pipeline_forward=ViTForImageClassification_pipeline_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the base stage layers plus ``vit.layernorm`` and ``classifier`` on the output stage.

        For interleaved schedules the output modules go to
        ``is_first_stage(ignore_chunk=True)`` when ``use_zbv`` is set, and to
        ``is_last_stage(ignore_chunk=True)`` otherwise. Non-interleaved
        schedules place them on ``is_last_stage()``.
        """
        held_layers = super().get_held_layers()
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        module = self.model.vit
        stage_manager = self.pipeline_stage_manager
        if stage_manager.is_interleave:
            holds_output = (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
            )
        else:
            holds_output = stage_manager.is_last_stage()

        if holds_output:
            held_layers.append(module.layernorm)
            held_layers.append(self.model.classifier)

        return held_layers
|
|
|
|
|
|
|
|
|
|
|
|
# ViTForMaskedImageModeling
|
|
|
|
class ViTForMaskedImageModelingPolicy(ViTPolicy):
    """Shardformer policy for ``ViTForMaskedImageModeling``."""

    def module_policy(self):
        """Extend the base ViT policy with pipeline forwards for the backbone and the MIM wrapper."""
        from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel

        policy = super().module_policy()
        if self.shard_config.pipeline_stage_manager is None:
            return policy

        # Backbone first, then the task wrapper, each with its own stage-aware forward.
        staged_forwards = (
            (ViTModel, ViTModel_pipeline_forward),
            (ViTForMaskedImageModeling, ViTForMaskedImageModeling_pipeline_forward),
        )
        for target_cls, forward_factory in staged_forwards:
            self.set_pipeline_forward(model_cls=target_cls, pipeline_forward=forward_factory, policy=policy)
        return policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the base stage layers plus ``vit.layernorm`` and ``decoder`` on the output stage.

        Interleaved + ``use_zbv``: output stage is ``is_first_stage(ignore_chunk=True)``.
        Interleaved without it: ``is_last_stage(ignore_chunk=True)``.
        Non-interleaved: ``is_last_stage()``.
        """
        held = super().get_held_layers()
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        manager = self.pipeline_stage_manager
        if not manager.is_interleave:
            on_output_stage = manager.is_last_stage()
        elif manager.use_zbv:
            on_output_stage = manager.is_first_stage(ignore_chunk=True)
        else:
            on_output_stage = manager.is_last_stage(ignore_chunk=True)

        if on_output_stage:
            held.extend((self.model.vit.layernorm, self.model.decoder))
        return held
|