|
|
@ -1,4 +1,3 @@
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
from typing import Callable, Dict, List, Union
|
|
|
|
from typing import Callable, Dict, List, Union
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
@ -36,7 +35,7 @@ class ViTPolicy(Policy):
|
|
|
|
suffix="dropout",
|
|
|
|
suffix="dropout",
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
])
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
|
|
|
|
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
|
|
|
|
"attention.attention.num_attention_heads":
|
|
|
|
"attention.attention.num_attention_heads":
|
|
|
@ -44,45 +43,47 @@ class ViTPolicy(Policy):
|
|
|
|
"attention.attention.all_head_size":
|
|
|
|
"attention.attention.all_head_size":
|
|
|
|
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
|
|
|
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
param_replacement=[],
|
|
|
|
param_replacement=[],
|
|
|
|
sub_module_replacement=[
|
|
|
|
sub_module_replacement=[
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attention.attention.query",
|
|
|
|
suffix="attention.attention.query",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attention.attention.key",
|
|
|
|
suffix="attention.attention.key",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attention.attention.value",
|
|
|
|
suffix="attention.attention.value",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attention.attention.dropout",
|
|
|
|
suffix="attention.attention.dropout",
|
|
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attention.output.dense",
|
|
|
|
suffix="attention.output.dense",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="attention.output.dropout",
|
|
|
|
suffix="attention.output.dropout",
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="intermediate.dense",
|
|
|
|
suffix="intermediate.dense",
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
target_module=col_nn.Linear1D_Col,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="output.dense",
|
|
|
|
suffix="output.dense",
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
|
|
|
target_module=col_nn.Linear1D_Row,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
suffix="output.dropout",
|
|
|
|
suffix="output.dropout",
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
])
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return policy
|
|
|
|
|
|
|
|
|
|
|
|
return policy
|
|
|
|
return policy
|
|
|
|
|
|
|
|
|
|
|
|