mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] bert support sequence parallel. (#4455)
* [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallelpull/4484/head
parent
0ecd71e041
commit
a27e0bb494
|
@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
|
@ -217,9 +217,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
# do all gather in default stream
|
||||
input_ = input_.contiguous()
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
|
||||
|
||||
# calculate gradient in calculate_stream
|
||||
|
@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None):
|
|||
|
||||
# all gather
|
||||
input_ = input_.contiguous()
|
||||
rank = dist.get_rank(process_group)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=process_group)
|
||||
|
||||
# concat
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
@ -29,6 +29,8 @@ from transformers.models.bert.modeling_bert import (
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
|
||||
|
||||
class BertPipelineForwards:
|
||||
|
@ -56,6 +58,7 @@ class BertPipelineForwards:
|
|||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
# TODO(jianghai): add explaination of the output here.
|
||||
r"""
|
||||
|
@ -177,6 +180,14 @@ class BertPipelineForwards:
|
|||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
# layer_outputs
|
||||
layer_outputs = hidden_states if hidden_states is not None else None
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config is not None and shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
|
||||
if stage_manager.is_first_stage() and idx == 0:
|
||||
encoder_attention_mask = encoder_extended_attention_mask
|
||||
|
@ -223,11 +234,17 @@ class BertPipelineForwards:
|
|||
all_cross_attentions = all_cross_attentions + \
|
||||
(layer_outputs[2],)
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config is not None and shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
# end of a stage loop
|
||||
sequence_output = layer_outputs[0] if layer_outputs is not None else None
|
||||
sequence_output = hidden_states if hidden_states is not None else None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
@ -268,6 +285,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -294,6 +312,7 @@ class BertPipelineForwards:
|
|||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states if hidden_states is not None else None,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
|
@ -350,6 +369,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
|
@ -404,7 +424,8 @@ class BertPipelineForwards:
|
|||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states if hidden_states is not None else None,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
past_key_values = None
|
||||
all_hidden_states = None
|
||||
all_self_attentions = None
|
||||
|
@ -457,6 +478,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
|
@ -491,6 +513,7 @@ class BertPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
|
@ -532,6 +555,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
||||
|
@ -594,7 +618,8 @@ class BertPipelineForwards:
|
|||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
|
@ -636,6 +661,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
|
@ -666,7 +692,8 @@ class BertPipelineForwards:
|
|||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
|
@ -726,6 +753,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
|
@ -742,21 +770,20 @@ class BertPipelineForwards:
|
|||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
outputs = BertPipelineForwards.bert_model_forward(self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
sequence_output = outputs[0]
|
||||
|
@ -799,6 +826,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
|
@ -843,6 +871,7 @@ class BertPipelineForwards:
|
|||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
|
@ -886,6 +915,7 @@ class BertPipelineForwards:
|
|||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
# NOTE: the arg start_position and end_position are used only for the last stage
|
||||
r"""
|
||||
|
@ -909,21 +939,20 @@ class BertPipelineForwards:
|
|||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
|
||||
outputs = BertPipelineForwards.bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
outputs = BertPipelineForwards.bert_model_forward(self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
if stage_manager.is_last_stage():
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
@ -1101,3 +1130,150 @@ def get_jit_fused_bert_output_forward():
|
|||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
token_type_ids = buffered_token_type_ids_expanded
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
embedding_output = split_forward_gather_backward(embedding_output,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
sequence_output = gather_forward_split_backward(sequence_output,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
|
|||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.bert import (
|
||||
BertPipelineForwards,
|
||||
bert_sequence_parallel_forward_fn,
|
||||
get_bert_flash_attention_forward,
|
||||
get_jit_fused_bert_output_forward,
|
||||
get_jit_fused_bert_self_output_forward,
|
||||
|
@ -47,13 +48,14 @@ class BertPolicy(Policy):
|
|||
from transformers.models.bert.modeling_bert import (
|
||||
BertEmbeddings,
|
||||
BertLayer,
|
||||
BertModel,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
|
||||
policy = {}
|
||||
|
||||
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
|
||||
"attention.self.all_head_size":
|
||||
|
@ -69,14 +71,17 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.query",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.key",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.dropout",
|
||||
|
@ -85,6 +90,7 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
|
@ -93,10 +99,12 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
|
@ -115,6 +123,12 @@ class BertPolicy(Policy):
|
|||
)
|
||||
])
|
||||
|
||||
if use_sequence_parallel:
|
||||
self.append_or_create_method_replacement(
|
||||
description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
|
||||
policy=policy,
|
||||
target_key=BertModel)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# Handle bert layer
|
||||
|
@ -205,7 +219,13 @@ class BertPolicy(Policy):
|
|||
|
||||
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
method_replacement = {
|
||||
'forward':
|
||||
partial(new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=self.shard_config)
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=policy,
|
||||
target_key=model_cls)
|
||||
|
|
Loading…
Reference in New Issue