@@ -75,6 +75,8 @@ class BertPolicy(Policy):
        sp_partial_derived = sp_mode == "split_gather"

        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -97,6 +99,7 @@ class BertPolicy(Policy):
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
@@ -105,6 +108,7 @@ class BertPolicy(Policy):
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
@@ -113,6 +117,7 @@ class BertPolicy(Policy):
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
@@ -125,6 +130,7 @@ class BertPolicy(Policy):
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
@@ -138,6 +144,7 @@ class BertPolicy(Policy):
                            "seq_parallel_mode": sp_mode,
                            "skip_bias_add": self.enable_bias_gelu_fused,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
@@ -146,6 +153,97 @@ class BertPolicy(Policy):
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ],
            )

            policy[BertEmbeddings] = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="dropout",
                        target_module=col_nn.DropoutForReplicatedInput,
                    ),
                ]
            )
            if self.enable_bias_gelu_fused:
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_jit_fused_bert_intermediate_forward(),
                    },
                    policy=policy,
                    target_key=BertIntermediate,
                )

        elif use_zbv:
            policy[BertLayer] = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=col_nn.LinearWithGradAccum,
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.key",
                        target_module=col_nn.LinearWithGradAccum,
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.value",
                        target_module=col_nn.LinearWithGradAccum,
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=col_nn.LinearWithGradAccum,
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate.dense",
                        target_module=col_nn.LinearWithGradAccum,
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "skip_bias_add": self.enable_bias_gelu_fused,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dense",
                        target_module=col_nn.LinearWithGradAccum,
                        kwargs={
                            "seq_parallel_mode": sp_mode,
                            "fp8_communication": self.shard_config.fp8_communication,
                            "use_zbv": use_zbv,
                        },
                    ),
                    SubModuleReplacementDescription(