[shardformer] bert support sequence parallel. (#4455)

Author: flybird11111 (committed via GitHub)
Date:   2023-08-18 18:04:55 +08:00
Commit: a27e0bb494 (parent 0ecd71e041)
3 changed files with 234 additions and 42 deletions

colossalai/shardformer/layer/_operation.py

@@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
         ctx.save_for_backward(input_, weight)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -217,9 +217,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # do all gather in default stream
             input_ = input_.contiguous()
             world_size = dist.get_world_size(process_group)
-            rank = dist.get_rank(process_group)
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            tensor_list[rank] = input_
             gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
 
             # calculate gradient in calculate_stream
@@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None):
 
     # all gather
    input_ = input_.contiguous()
-    rank = dist.get_rank(process_group)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
     torch.distributed.all_gather(tensor_list, input_, group=process_group)
 
     # concat
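Note on the two hunks above: `dist.all_gather` writes every rank's tensor into `tensor_list`, including the calling rank's own slot, so the removed `rank = dist.get_rank(...)` / `tensor_list[rank] = input_` pre-assignment was redundant. Below is a minimal sketch of that semantics together with the `async_op=True` overlap pattern used here; it assumes a single-process "gloo" group on an arbitrary localhost port purely so the script can run standalone.

import torch
import torch.distributed as dist


def demo_all_gather():
    # single-process group just to make the sketch runnable on CPU
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    input_ = torch.arange(4, dtype=torch.float32).contiguous()
    world_size = dist.get_world_size()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

    # async_op=True returns a work handle, so independent computation can overlap the collective
    handle = dist.all_gather(tensor_list, input_, async_op=True)
    unrelated = input_ * 2    # work that does not depend on the gathered result
    handle.wait()             # synchronize before reading tensor_list

    # all_gather has filled every slot, including this rank's own copy
    assert torch.equal(tensor_list[dist.get_rank()], input_)
    print(unrelated.sum().item())
    dist.destroy_process_group()


if __name__ == "__main__":
    demo_all_gather()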

colossalai/shardformer/modeling/bert.py

@@ -1,6 +1,6 @@
 import math
 import warnings
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -29,6 +29,8 @@ from transformers.models.bert.modeling_bert import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 
 
 class BertPipelineForwards:
@@ -56,6 +58,7 @@ class BertPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,    # this is from the previous stage
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         # TODO(jianghai): add explaination of the output here.
         r"""
@@ -177,6 +180,14 @@ class BertPipelineForwards:
         start_idx, end_idx = stage_index[0], stage_index[1]
         # layer_outputs
         layer_outputs = hidden_states if hidden_states is not None else None
+
+        # split the input tensor along sequence dimension
+        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+        if shard_config is not None and shard_config.enable_sequence_parallelism:
+            hidden_states = split_forward_gather_backward(hidden_states,
+                                                          dim=1,
+                                                          process_group=shard_config.tensor_parallel_process_group)
+
         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
             if stage_manager.is_first_stage() and idx == 0:
                 encoder_attention_mask = encoder_extended_attention_mask
@@ -223,11 +234,17 @@ class BertPipelineForwards:
                 all_cross_attentions = all_cross_attentions + \
                     (layer_outputs[2],)
 
+        # When sequence parallelism done, gather the output tensor in forward and split it in backward
+        if shard_config is not None and shard_config.enable_sequence_parallelism:
+            hidden_states = gather_forward_split_backward(hidden_states,
+                                                          dim=1,
+                                                          process_group=shard_config.tensor_parallel_process_group)
+
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         # end of a stage loop
-        sequence_output = layer_outputs[0] if layer_outputs is not None else None
+        sequence_output = hidden_states if hidden_states is not None else None
 
         if stage_manager.is_last_stage():
             pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
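For readers new to the sequence-parallel layout: the split above turns each stage's activations from [batch_size, seq_len, hidden_size] into [batch_size, seq_len/TP_size, hidden_size], and the gather restores the full sequence before the pooler. A standalone shape illustration using plain torch.chunk/torch.cat follows; the real split_forward_gather_backward / gather_forward_split_backward are autograd functions whose backward passes run the mirrored collective across the tensor-parallel group.

import torch


def split_along_seq(hidden_states: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len / tp_size, hidden_size]
    return torch.chunk(hidden_states, tp_size, dim=1)[tp_rank].contiguous()


def gather_along_seq(shards):
    # inverse of the split: concatenate the per-rank shards back along the sequence dimension
    return torch.cat(shards, dim=1)


x = torch.randn(2, 8, 16)    # batch_size=2, seq_len=8, hidden_size=16
shards = [split_along_seq(x, tp_size=4, tp_rank=r) for r in range(4)]
print(shards[0].shape)                           # torch.Size([2, 2, 16])
print(torch.equal(gather_along_seq(shards), x))  # True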
@@ -268,6 +285,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         logger = logging.get_logger(__name__)
@@ -294,6 +312,7 @@ class BertPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states if hidden_states is not None else None,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         past_key_values = None
         all_hidden_states = None
@@ -350,6 +369,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -404,7 +424,8 @@ class BertPipelineForwards:
             return_dict=return_dict,
             stage_manager=stage_manager,
             hidden_states=hidden_states if hidden_states is not None else None,
-            stage_index=stage_index)
+            stage_index=stage_index,
+            shard_config=shard_config)
         past_key_values = None
         all_hidden_states = None
         all_self_attentions = None
@@ -457,6 +478,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.Tensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -491,6 +513,7 @@ class BertPipelineForwards:
             hidden_states=hidden_states,
             stage_manager=stage_manager,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
 
         if stage_manager.is_last_stage():
@@ -532,6 +555,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.Tensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
         **kwargs,
     ):
         # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
@@ -594,7 +618,8 @@ class BertPipelineForwards:
             return_dict=return_dict,
             hidden_states=hidden_states,
             stage_manager=stage_manager,
-            stage_index=stage_index)
+            stage_index=stage_index,
+            shard_config=shard_config)
 
         if stage_manager.is_last_stage():
             pooled_output = outputs[1]
@@ -636,6 +661,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.Tensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -666,7 +692,8 @@ class BertPipelineForwards:
             return_dict=return_dict,
             hidden_states=hidden_states,
             stage_manager=stage_manager,
-            stage_index=stage_index)
+            stage_index=stage_index,
+            shard_config=shard_config)
 
         if stage_manager.is_last_stage():
             pooled_output = outputs[1]
@@ -726,6 +753,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.Tensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -742,21 +770,20 @@ class BertPipelineForwards:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
 
-        outputs = BertPipelineForwards.bert_model_forward(
-            self.bert,
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            hidden_states=hidden_states,
-            stage_manager=stage_manager,
-            stage_index=stage_index,
-        )
+        outputs = BertPipelineForwards.bert_model_forward(self.bert,
+                                                          input_ids,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          hidden_states=hidden_states,
+                                                          stage_manager=stage_manager,
+                                                          stage_index=stage_index,
+                                                          shard_config=shard_config)
 
         if stage_manager.is_last_stage():
             sequence_output = outputs[0]
@@ -799,6 +826,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.Tensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -843,6 +871,7 @@ class BertPipelineForwards:
             hidden_states=hidden_states,
             stage_manager=stage_manager,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
 
         if stage_manager.is_last_stage():
             pooled_output = outputs[1]
@@ -886,6 +915,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.Tensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         # NOTE: the arg start_position and end_position are used only for the last stage
         r"""
@@ -909,21 +939,20 @@ class BertPipelineForwards:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
 
-        outputs = BertPipelineForwards.bert_model_forward(
-            self.bert,
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            hidden_states=hidden_states,
-            stage_manager=stage_manager,
-            stage_index=stage_index,
-        )
+        outputs = BertPipelineForwards.bert_model_forward(self.bert,
+                                                          input_ids,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          hidden_states=hidden_states,
+                                                          stage_manager=stage_manager,
+                                                          stage_index=stage_index,
+                                                          shard_config=shard_config)
 
         if stage_manager.is_last_stage():
             sequence_output = outputs[0]
@@ -1101,3 +1130,150 @@ def get_jit_fused_bert_output_forward():
         return hidden_states
 
     return forward
+
+
+def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+        if token_type_ids is None:
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        # split the input tensor along sequence dimension
+        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+        embedding_output = split_forward_gather_backward(embedding_output,
+                                                         dim=1,
+                                                         process_group=shard_config.tensor_parallel_process_group)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+
+        # When sequence parallelism done, gather the output tensor in forward and split it in backward
+        sequence_output = gather_forward_split_backward(sequence_output,
+                                                        dim=1,
+                                                        process_group=shard_config.tensor_parallel_process_group)
+
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+    return forward
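`bert_sequence_parallel_forward_fn` is a closure: it captures `shard_config` and returns an unbound `forward` that the policy below installs on `BertModel` through `append_or_create_method_replacement`. As a rough illustration of the binding pattern only (the `model` and `shard_config` objects here are hypothetical placeholders, and the sharder's actual mechanism may differ):

import types

# hypothetical: attach the closure-generated forward to an existing BertModel instance
new_forward = bert_sequence_parallel_forward_fn(shard_config)    # plain function taking (self, ...)
model.forward = types.MethodType(new_forward, model)             # model(...) now runs the sequence-parallel path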

colossalai/shardformer/policies/bert.py

@@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
 from .._utils import getattr_, setattr_
 from ..modeling.bert import (
     BertPipelineForwards,
+    bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -47,13 +48,14 @@ class BertPolicy(Policy):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
             BertLayer,
+            BertModel,
             BertOutput,
             BertSelfAttention,
             BertSelfOutput,
         )
 
         policy = {}
+        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
 
         if self.shard_config.enable_tensor_parallelism:
             policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
                 "attention.self.all_head_size":
@@ -69,14 +71,17 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.self.query",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.key",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.value",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.dropout",
@@ -85,6 +90,7 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dropout",
@@ -93,10 +99,12 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dropout",
@@ -115,6 +123,12 @@ class BertPolicy(Policy):
                 )
             ])
 
+        if use_sequence_parallel:
+            self.append_or_create_method_replacement(
+                description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
+                policy=policy,
+                target_key=BertModel)
+
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             # Handle bert layer
@@ -205,7 +219,13 @@ class BertPolicy(Policy):
             layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            method_replacement = {
+                'forward':
+                    partial(new_forward,
+                            stage_manager=stage_manager,
+                            stage_index=stage_index,
+                            shard_config=self.shard_config)
+            }
             self.append_or_create_method_replacement(description=method_replacement,
                                                      policy=policy,
                                                      target_key=model_cls)
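For context, a hedged sketch of how a user would reach this code path: build a `ShardConfig` with `enable_sequence_parallelism=True` (the flag `BertPolicy` checks above) and let `ShardFormer` apply the policy. It assumes a `torchrun` launch and the ShardFormer API of this ColossalAI version; exact argument names and the distributed setup may differ.

import torch.distributed as dist
from transformers import BertForSequenceClassification

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch(config={})    # assumes the script was started with torchrun

shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,    # sequence parallelism reuses the TP group
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,
)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)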