@ -84,17 +84,26 @@ class BertPolicy(Policy):
SubModuleReplacementDescription (
suffix = " attention.self.query " ,
target_module = col_nn . Linear1D_Col ,
kwargs = { " seq_parallel " : use_sequence_parallel , " overlap " : overlap } ,
kwargs = {
" seq_parallel " : use_sequence_parallel ,
" overlap " : overlap ,
} ,
) ,
SubModuleReplacementDescription (
suffix = " attention.self.key " ,
target_module = col_nn . Linear1D_Col ,
kwargs = { " seq_parallel " : use_sequence_parallel , " overlap " : overlap } ,
kwargs = {
" seq_parallel " : use_sequence_parallel ,
" overlap " : overlap ,
} ,
) ,
SubModuleReplacementDescription (
suffix = " attention.self.value " ,
target_module = col_nn . Linear1D_Col ,
kwargs = { " seq_parallel " : use_sequence_parallel , " overlap " : overlap } ,
kwargs = {
" seq_parallel " : use_sequence_parallel ,
" overlap " : overlap ,
} ,
) ,
SubModuleReplacementDescription (
suffix = " attention.self.dropout " ,
@ -112,7 +121,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription (
suffix = " intermediate.dense " ,
target_module = col_nn . Linear1D_Col ,
kwargs = { " seq_parallel " : use_sequence_parallel , " overlap " : overlap } ,
kwargs = {
" seq_parallel " : use_sequence_parallel ,
" overlap " : overlap ,
} ,
) ,
SubModuleReplacementDescription (
suffix = " output.dense " ,
@ -214,7 +226,9 @@ class BertPolicy(Policy):
if self . shard_config . enable_tensor_parallelism :
self . append_or_create_submodule_replacement (
description = SubModuleReplacementDescription (
suffix = " decoder " , target_module = col_nn . Linear1D_Col , kwargs = { " gather_output " : True }
suffix = " decoder " ,
target_module = col_nn . Linear1D_Col ,
kwargs = { " gather_output " : True } ,
) ,
policy = base_policy ,
target_key = BertLMPredictionHead ,
@ -241,7 +255,9 @@ class BertPolicy(Policy):
" _load_from_state_dict " : col_nn . ParallelModule . _load_from_state_dict ,
}
self . append_or_create_method_replacement (
description = method_replacement , policy = base_policy , target_key = BertLMPredictionHead
description = method_replacement ,
policy = base_policy ,
target_key = BertLMPredictionHead ,
)
return base_policy
@ -264,24 +280,32 @@ class BertPolicy(Policy):
if stage_manager . is_interleave :
layers_per_stage = self . distribute_layers (
len ( module . encoder . layer ) , stage_manager . num_stages * stage_manager . num_model_chunks
len ( module . encoder . layer ) ,
stage_manager . num_stages * stage_manager . num_model_chunks ,
)
stage_manager . stage_indices = Policy . get_stage_index (
stage_manager . stage_indices = self . get_stage_index (
layers_per_stage ,
stage_manager . stage ,
num_model_chunks = stage_manager . num_model_chunks ,
num_stages = stage_manager . num_stages ,
)
method_replacement = {
" forward " : partial ( new_forward , stage_manager = stage_manager , shard_config = self . shard_config )
" forward " : partial (
new_forward ,
stage_manager = stage_manager ,
shard_config = self . shard_config ,
)
}
else :
layers_per_stage = Policy . distribute_layers ( len ( module . encoder . layer ) , stage_manager . num_stages )
stage_index = Policy . get_stage_index ( layers_per_stage , stage_manager . stage )
layers_per_stage = self . distribute_layers ( len ( module . encoder . layer ) , stage_manager . num_stages )
stage_index = self . get_stage_index ( layers_per_stage , stage_manager . stage )
method_replacement = {
" forward " : partial (
new_forward , stage_manager = stage_manager , stage_index = stage_index , shard_config = self . shard_config
new_forward ,
stage_manager = stage_manager ,
stage_index = stage_index ,
shard_config = self . shard_config ,
)
}
@ -301,9 +325,10 @@ class BertPolicy(Policy):
if stage_manager . is_interleave :
assert stage_manager . num_model_chunks is not None
layers_per_stage = self . distribute_layers (
len ( module . encoder . layer ) , stage_manager . num_stages * stage_manager . num_model_chunks
len ( module . encoder . layer ) ,
stage_manager . num_stages * stage_manager . num_model_chunks ,
)
stage_indices = Policy . get_stage_index (
stage_indices = self . get_stage_index (
layers_per_stage ,
stage_manager . stage ,
num_model_chunks = stage_manager . num_model_chunks ,
@ -320,7 +345,7 @@ class BertPolicy(Policy):
layers_per_stage = self . distribute_layers ( len ( module . encoder . layer ) , stage_manager . num_stages )
if stage_manager . is_first_stage ( ) :
held_layers . append ( module . embeddings )
start_idx , end_idx = Policy . get_stage_index ( layers_per_stage , stage_manager . stage )
start_idx , end_idx = self . get_stage_index ( layers_per_stage , stage_manager . stage )
held_layers . extend ( module . encoder . layer [ start_idx : end_idx ] )
if stage_manager . is_last_stage ( ) :
held_layers . append ( module . pooler )
@ -336,7 +361,9 @@ class BertModelPolicy(BertPolicy):
if self . pipeline_stage_manager :
self . set_pipeline_forward (
model_cls = BertModel , new_forward = BertPipelineForwards . bert_model_forward , policy = policy
model_cls = BertModel ,
new_forward = BertPipelineForwards . bert_model_forward ,
policy = policy ,
)
return policy
@ -399,7 +426,9 @@ class BertLMHeadModelPolicy(BertPolicy):
if self . pipeline_stage_manager :
self . set_pipeline_forward (
model_cls = BertLMHeadModel , new_forward = BertPipelineForwards . bert_lm_head_model_forward , policy = policy
model_cls = BertLMHeadModel ,
new_forward = BertPipelineForwards . bert_lm_head_model_forward ,
policy = policy ,
)
return policy
@ -437,7 +466,9 @@ class BertForMaskedLMPolicy(BertPolicy):
if self . pipeline_stage_manager :
self . set_pipeline_forward (
model_cls = BertForMaskedLM , new_forward = BertPipelineForwards . bert_for_masked_lm_forward , policy = policy
model_cls = BertForMaskedLM ,
new_forward = BertPipelineForwards . bert_for_masked_lm_forward ,
policy = policy ,
)
return policy