2023-06-28 05:28:18 +00:00
|
|
|
from typing import Dict, Union
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-06-30 01:32:37 +00:00
|
|
|
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
2023-06-28 05:28:18 +00:00
|
|
|
|
|
|
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
|
2023-06-30 02:56:29 +00:00
|
|
|
# Names exported by a star-import of this module.
__all__ = ['ViTPolicy']
|
|
|
|
|
2023-06-28 07:04:35 +00:00
|
|
|
|
2023-06-28 05:28:18 +00:00
|
|
|
class ViTPolicy(Policy):
    """Sharding policy describing how ViT sub-modules are replaced.

    The policy reads ``self.model`` and ``self.shard_config`` and returns, from
    :meth:`module_policy`, a mapping from ViT module classes to the attribute
    and sub-module replacements that should be applied to them.
    """

    def config_sanity_check(self):
        """Perform no validation; present to satisfy the policy interface."""
        pass

    def preprocess(self):
        """Pad the token embedding, if the model has one, and return the model.

        When the model config exposes ``vocab_size`` and that size is not a
        multiple of ``shard_config.tensor_parallel_size``, the embedding is
        resized up to the next multiple. Configs without ``vocab_size`` are
        left untouched.

        Returns:
            The (possibly resized) model held by this policy.
        """
        # Guarded read: a config without ``vocab_size`` would otherwise raise
        # AttributeError here (ViT configs are believed not to define one).
        vocab_size = getattr(self.model.config, "vocab_size", None)
        world_size = self.shard_config.tensor_parallel_size

        if vocab_size is not None and vocab_size % world_size != 0:
            # Round up to the next multiple of the tensor-parallel size.
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the replacement description for every sharded ViT module.

        Returns:
            A dict keyed by module class. ``ViTEmbeddings`` and ``ViTLayer``
            are always present; when
            ``shard_config.enable_fused_normalization`` is set, the layer norms
            of ``ViTLayer`` and ``ViTModel`` are additionally replaced by
            ``FusedLayerNorm``.
        """
        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel

        # The module-level import only brings in DropoutForReplicatedInput;
        # import the parallel-input variant here so the name is defined.
        from colossalai.shardformer.layer import DropoutForParallelInput

        tp_size = self.shard_config.tensor_parallel_size
        config = self.model.config

        # (suffix, replacement class) pairs applied to every ViTLayer.
        layer_targets = [
            ("attention.attention.query", Linear1D_Col),
            ("attention.attention.key", Linear1D_Col),
            ("attention.attention.value", Linear1D_Col),
            ("attention.attention.dropout", DropoutForParallelInput),
            ("attention.output.dense", Linear1D_Row),
            ("attention.output.dropout", DropoutForParallelInput),
            ("intermediate.dense", Linear1D_Col),
            ("output.dense", Linear1D_Row),
            ("output.dropout", DropoutForParallelInput),
        ]

        base_policy = {
            ViTEmbeddings:
                ModulePolicyDescription(sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="dropout",
                        target_module=DropoutForReplicatedInput,
                    )
                ]),
            ViTLayer:
                ModulePolicyDescription(
                    attribute_replacement={
                        # Head count and total head size are divided by the
                        # tensor-parallel size.
                        "attention.attention.num_attention_heads": config.num_attention_heads // tp_size,
                        "attention.attention.all_head_size": config.hidden_size // tp_size,
                    },
                    sub_module_replacement=[
                        SubModuleReplacementDescription(suffix=suffix, target_module=target)
                        for suffix, target in layer_targets
                    ],
                ),
        }

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # layernorm_before / layernorm_after are registered on the
            # ViTLayer entry, the only layer-level key defined in base_policy.
            base_policy[ViTLayer].sub_module_replacement.extend([
                SubModuleReplacementDescription(
                    suffix="layernorm_before",
                    target_module=FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="layernorm_after",
                    target_module=FusedLayerNorm,
                ),
            ])
            # base_policy has no ViTModel entry yet, so create one for the
            # "layernorm" replacement instead of indexing a missing key.
            base_policy[ViTModel] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="layernorm",
                    target_module=FusedLayerNorm,
                )
            ])

        return base_policy

    def new_model_class(self):
        """Return ``None``: this policy does not substitute the model class."""
        return None

    def postprocess(self):
        """Return the model unchanged."""
        return self.model