mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 4441 (#4445)
Co-authored-by: github-actions <github-actions@github.com>pull/4426/merge
parent
5d4efdf58f
commit
d20dceb9a3
|
@ -40,7 +40,7 @@ class ViTPolicy(Policy):
|
||||||
suffix="dropout",
|
suffix="dropout",
|
||||||
target_module=DropoutForReplicatedInput,
|
target_module=DropoutForReplicatedInput,
|
||||||
)
|
)
|
||||||
])
|
])
|
||||||
|
|
||||||
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
|
policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
|
||||||
"attention.attention.num_attention_heads":
|
"attention.attention.num_attention_heads":
|
||||||
|
@ -48,45 +48,45 @@ class ViTPolicy(Policy):
|
||||||
"attention.attention.all_head_size":
|
"attention.attention.all_head_size":
|
||||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
},
|
},
|
||||||
param_replacement=[],
|
param_replacement=[],
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attention.attention.query",
|
suffix="attention.attention.query",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attention.attention.key",
|
suffix="attention.attention.key",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attention.attention.value",
|
suffix="attention.attention.value",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attention.attention.dropout",
|
suffix="attention.attention.dropout",
|
||||||
target_module=col_nn.DropoutForParallelInput,
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attention.output.dense",
|
suffix="attention.output.dense",
|
||||||
target_module=col_nn.Linear1D_Row,
|
target_module=col_nn.Linear1D_Row,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attention.output.dropout",
|
suffix="attention.output.dropout",
|
||||||
target_module=col_nn.DropoutForReplicatedInput,
|
target_module=col_nn.DropoutForReplicatedInput,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="intermediate.dense",
|
suffix="intermediate.dense",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="output.dense",
|
suffix="output.dense",
|
||||||
target_module=col_nn.Linear1D_Row,
|
target_module=col_nn.Linear1D_Row,
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="output.dropout",
|
suffix="output.dropout",
|
||||||
target_module=col_nn.DropoutForReplicatedInput,
|
target_module=col_nn.DropoutForReplicatedInput,
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
|
|
||||||
# use flash attention
|
# use flash attention
|
||||||
if self.shard_config.enable_flash_attention:
|
if self.shard_config.enable_flash_attention:
|
||||||
|
|
|
@ -21,7 +21,7 @@ def check_stage_manager():
|
||||||
1: [0, 1],
|
1: [0, 1],
|
||||||
2: [2, 3],
|
2: [2, 3],
|
||||||
3: [2, 3],
|
3: [2, 3],
|
||||||
}
|
}
|
||||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
Loading…
Reference in New Issue