|
|
@ -75,6 +75,8 @@ class BertPolicy(Policy):
|
|
|
|
|
|
|
|
|
|
|
|
sp_partial_derived = sp_mode == "split_gather"
|
|
|
|
sp_partial_derived = sp_mode == "split_gather"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
|
|
|
|
|
|
|
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
|
|
assert (
|
|
|
|
assert (
|
|
|
|
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
|
|
|
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
|
|
@ -97,6 +99,7 @@ class BertPolicy(Policy):
|
|
|
|
kwargs={
|
|
|
|
kwargs={
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
@ -105,6 +108,7 @@ class BertPolicy(Policy):
|
|
|
|
kwargs={
|
|
|
|
kwargs={
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
@ -113,6 +117,7 @@ class BertPolicy(Policy):
|
|
|
|
kwargs={
|
|
|
|
kwargs={
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
@ -125,6 +130,7 @@ class BertPolicy(Policy):
|
|
|
|
kwargs={
|
|
|
|
kwargs={
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
@ -138,6 +144,7 @@ class BertPolicy(Policy):
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
|
|
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
@ -146,6 +153,97 @@ class BertPolicy(Policy):
|
|
|
|
kwargs={
|
|
|
|
kwargs={
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="output.dropout",
|
|
|
|
|
|
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
policy[BertEmbeddings] = ModulePolicyDescription(
|
|
|
|
|
|
|
|
sub_module_replacement=[
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="dropout",
|
|
|
|
|
|
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
if self.enable_bias_gelu_fused:
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
|
|
|
|
description={
|
|
|
|
|
|
|
|
"forward": get_jit_fused_bert_intermediate_forward(),
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
policy=policy,
|
|
|
|
|
|
|
|
target_key=BertIntermediate,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif use_zbv:
|
|
|
|
|
|
|
|
policy[BertLayer] = ModulePolicyDescription(
|
|
|
|
|
|
|
|
sub_module_replacement=[
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="attention.self.query",
|
|
|
|
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
|
|
|
|
kwargs={
|
|
|
|
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="attention.self.key",
|
|
|
|
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
|
|
|
|
kwargs={
|
|
|
|
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="attention.self.value",
|
|
|
|
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
|
|
|
|
kwargs={
|
|
|
|
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="attention.self.dropout",
|
|
|
|
|
|
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="attention.output.dense",
|
|
|
|
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
|
|
|
|
kwargs={
|
|
|
|
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="attention.output.dropout",
|
|
|
|
|
|
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="intermediate.dense",
|
|
|
|
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
|
|
|
|
kwargs={
|
|
|
|
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
|
|
|
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
|
|
|
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
|
|
|
suffix="output.dense",
|
|
|
|
|
|
|
|
target_module=col_nn.LinearWithGradAccum,
|
|
|
|
|
|
|
|
kwargs={
|
|
|
|
|
|
|
|
"seq_parallel_mode": sp_mode,
|
|
|
|
|
|
|
|
"fp8_communication": self.shard_config.fp8_communication,
|
|
|
|
|
|
|
|
"use_zbv": use_zbv,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
),
|
|
|
|
),
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|