2023-08-24 07:50:02 +00:00
|
|
|
import warnings
|
2023-07-25 07:02:29 +00:00
|
|
|
from typing import Callable, Dict, List, Union
|
2023-06-28 05:28:18 +00:00
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-07-25 07:02:29 +00:00
|
|
|
import colossalai.shardformer.layer as col_nn
|
2023-08-07 08:41:07 +00:00
|
|
|
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col
|
2023-06-28 05:28:18 +00:00
|
|
|
|
2023-08-07 08:41:07 +00:00
|
|
|
from ..modeling.jit import get_jit_fused_dropout_add_func
|
2023-07-25 07:02:29 +00:00
|
|
|
from ..modeling.vit import (
|
|
|
|
ViTForImageClassification_pipeline_forward,
|
|
|
|
ViTForMaskedImageModeling_pipeline_forward,
|
|
|
|
ViTModel_pipeline_forward,
|
2023-08-07 08:41:07 +00:00
|
|
|
get_jit_fused_vit_output_forward,
|
|
|
|
get_vit_flash_self_attention_forward,
|
2023-07-25 07:02:29 +00:00
|
|
|
)
|
2023-07-05 07:13:00 +00:00
|
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
2023-06-28 05:28:18 +00:00
|
|
|
|
2023-07-25 07:02:29 +00:00
|
|
|
__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']
|
2023-06-30 02:56:29 +00:00
|
|
|
|
2023-06-28 07:04:35 +00:00
|
|
|
|
2023-06-28 05:28:18 +00:00
|
|
|
class ViTPolicy(Policy):
    """Base sharding policy shared by the ViT model variants.

    Builds the module-replacement policy (tensor parallelism, flash
    attention, JIT-fused operators) and provides the pipeline helpers
    ``get_held_layers`` and ``set_pipeline_forward``.
    """

    def config_sanity_check(self):
        """Perform no checks; validation happens in ``module_policy``."""
        pass

    def preprocess(self):
        """Return the model unchanged."""
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the replacement policy keyed by HuggingFace ViT module class.

        Returns:
            A dict mapping module classes to their ``ModulePolicyDescription``.

        Raises:
            ValueError: If tensor parallelism is enabled and the number of
                attention heads is not divisible by the tensor parallel size.
        """
        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention

        policy = {}

        # Sequence parallelism is not supported here: reset the flag on the
        # shared shard config and tell the user it is being ignored.
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")

        if self.shard_config.enable_tensor_parallelism:
            tp_size = self.shard_config.tensor_parallel_size
            num_heads = self.model.config.num_attention_heads
            hidden_size = self.model.config.hidden_size

            # The per-rank head count below uses floor division; without this
            # check a non-divisible setup would silently get a truncated value.
            if num_heads % tp_size != 0:
                raise ValueError(f"The number of attention heads ({num_heads}) must be divisible by "
                                 f"the tensor parallel size ({tp_size}).")

            policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                            param_replacement=[],
                                                            sub_module_replacement=[
                                                                SubModuleReplacementDescription(
                                                                    suffix="dropout",
                                                                    target_module=DropoutForReplicatedInput,
                                                                )
                                                            ])

            # (sub-module suffix inside ViTLayer, replacement module class)
            layer_replacements = [
                ("attention.attention.query", col_nn.Linear1D_Col),
                ("attention.attention.key", col_nn.Linear1D_Col),
                ("attention.attention.value", col_nn.Linear1D_Col),
                ("attention.attention.dropout", col_nn.DropoutForParallelInput),
                ("attention.output.dense", col_nn.Linear1D_Row),
                ("attention.output.dropout", col_nn.DropoutForReplicatedInput),
                ("intermediate.dense", col_nn.Linear1D_Col),
                ("output.dense", col_nn.Linear1D_Row),
                ("output.dropout", col_nn.DropoutForReplicatedInput),
            ]
            policy[ViTLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.attention.num_attention_heads": num_heads // tp_size,
                    "attention.attention.all_head_size": hidden_size // tp_size,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix=suffix, target_module=target_module)
                    for suffix, target_module in layer_replacements
                ])

        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(description={
                'forward': get_vit_flash_self_attention_forward(),
            },
                                                     policy=policy,
                                                     target_key=ViTSelfAttention)

        # use jit fused operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(description={
                'forward': get_jit_fused_vit_output_forward(),
                'dropout_add': get_jit_fused_dropout_add_func(),
            },
                                                     policy=policy,
                                                     target_key=ViTOutput)
        return policy

    def new_model_class(self):
        """Return ``None``: this policy supplies no replacement model class."""
        return None

    def postprocess(self):
        """Return the model unchanged."""
        return self.model

    def _vit_backbone(self):
        """Return the model itself when it is a ``ViTModel``, else its ``vit`` attribute."""
        if self.model.__class__.__name__ == 'ViTModel':
            return self.model
        return self.model.vit

    def get_held_layers(self) -> List[nn.Module]:
        """Get pipeline layers for current stage.

        The first stage also holds the embeddings; every stage holds its
        slice of the encoder layers.
        """
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        module = self._vit_backbone()
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embeddings)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.encoder.layer[start_idx:end_idx])
        return held_layers

    def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
        """Register a pipeline-aware ``forward`` for ``model_cls`` in ``policy``.

        Does nothing when no pipeline stage manager is configured.

        Args:
            model_cls: Module class whose ``forward`` is replaced.
            pipeline_forward: Factory called with ``stage_manager`` and
                ``stage_index``; its result becomes the new ``forward``.
            policy: Policy dict to update.
        """
        if not self.pipeline_stage_manager:
            return

        stage_manager = self.pipeline_stage_manager
        module = self._vit_backbone()

        # Use the same layer-distribution calls as get_held_layers so the
        # forward's stage index matches the layers this stage holds.
        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
        method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
        self.append_or_create_method_replacement(description=method_replacement,
                                                 policy=policy,
                                                 target_key=model_cls)
|
|
|
|
|
|
|
|
|
|
|
|
# ViTModel
|
|
|
|
class ViTModelPolicy(ViTPolicy):
    """Policy for ``ViTModel``."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Extend the base policy with the ``ViTModel`` pipeline forward.

        The pipeline forward is registered only when the shard config has a
        pipeline stage manager.
        """
        from transformers.models.vit.modeling_vit import ViTModel

        vit_policy = super().module_policy()

        if self.shard_config.pipeline_stage_manager is None:
            return vit_policy

        self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=vit_policy)
        return vit_policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the layers held by the current pipeline stage.

        Starts from the base policy's layers; on the last stage the model's
        ``layernorm`` and ``pooler`` are added.
        """
        layers = super().get_held_layers()
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        vit = self.model
        manager = self.pipeline_stage_manager
        if manager.is_last_stage():
            layers.extend((vit.layernorm, vit.pooler))

        return layers
|
|
|
|
|
|
|
|
|
|
|
|
# ViTForImageClassification
|
|
|
|
class ViTForImageClassificationPolicy(ViTPolicy):
    """Policy for ``ViTForImageClassification``."""

    def module_policy(self):
        """Extend the base policy for the classification model.

        With tensor parallelism enabled, the ``classifier`` sub-module is
        replaced by ``Linear1D_Col`` created with ``gather_output=True``.
        With a pipeline stage manager configured, pipeline forwards are
        registered for ``ViTModel`` and ``ViTForImageClassification``.
        """
        from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel

        cls_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            classifier_replacement = SubModuleReplacementDescription(
                suffix="classifier",
                target_module=Linear1D_Col,
                kwargs={"gather_output": True},
            )
            cls_policy[ViTForImageClassification] = ModulePolicyDescription(
                sub_module_replacement=[classifier_replacement])

        if self.shard_config.pipeline_stage_manager is None:
            return cls_policy

        self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=cls_policy)
        self.set_pipeline_forward(model_cls=ViTForImageClassification,
                                  pipeline_forward=ViTForImageClassification_pipeline_forward,
                                  policy=cls_policy)
        return cls_policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the layers held by the current pipeline stage.

        Starts from the base policy's layers; on the last stage the
        backbone's ``layernorm`` and the model's ``classifier`` are added.
        """
        layers = super().get_held_layers()
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        backbone = self.model.vit
        manager = self.pipeline_stage_manager
        if manager.is_last_stage():
            layers.extend((backbone.layernorm, self.model.classifier))

        return layers
|
|
|
|
|
|
|
|
|
|
|
|
# ViTForMaskedImageModeling
|
|
|
|
class ViTForMaskedImageModelingPolicy(ViTPolicy):
    """Policy for ``ViTForMaskedImageModeling``."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Extend the base policy with the pipeline forwards.

        With a pipeline stage manager configured, pipeline forwards are
        registered for ``ViTModel`` and ``ViTForMaskedImageModeling``.
        """
        from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel

        mim_policy = super().module_policy()

        if self.shard_config.pipeline_stage_manager is None:
            return mim_policy

        self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=mim_policy)
        self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling,
                                  pipeline_forward=ViTForMaskedImageModeling_pipeline_forward,
                                  policy=mim_policy)
        return mim_policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the layers held by the current pipeline stage.

        Starts from the base policy's layers; on the last stage the
        backbone's ``layernorm`` and the model's ``decoder`` are added.
        """
        layers = super().get_held_layers()
        assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"

        backbone = self.model.vit
        manager = self.pipeline_stage_manager
        if manager.is_last_stage():
            layers.extend((backbone.layernorm, self.model.decoder))

        return layers
|