Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout
pull/4445/head
Kun Lin 1 year ago committed by Hongxin Liu
parent 0ceec8f9a9
commit c59d7aca09

@ -1,4 +1,3 @@
from functools import partial
from typing import Callable, Dict, List, Union from typing import Callable, Dict, List, Union
import torch.nn as nn import torch.nn as nn
@ -36,7 +35,7 @@ class ViTPolicy(Policy):
suffix="dropout", suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
) )
]) ])
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads": "attention.attention.num_attention_heads":
@ -44,45 +43,47 @@ class ViTPolicy(Policy):
"attention.attention.all_head_size": "attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size, self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
}, },
param_replacement=[], param_replacement=[],
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.query", suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.key", suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.value", suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.attention.dropout", suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dense", suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dropout", suffix="attention.output.dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="intermediate.dense", suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dense", suffix="output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dropout", suffix="output.dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
), ),
]) ])
return policy
return policy return policy

Loading…
Cancel
Save