diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index fc13aca79..f1f48273c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -217,9 +217,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): # do all gather in default stream input_ = input_.contiguous() world_size = dist.get_world_size(process_group) - rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) # calculate gradient in calculate_stream @@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None): # all gather input_ = input_.contiguous() - rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=process_group) # concat diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 5bd1c531c..d88661953 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -29,6 +29,8 @@ from transformers.models.bert.modeling_bert import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward class BertPipelineForwards: @@ -56,6 +58,7 @@ class BertPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # TODO(jianghai): add explaination of the output here. r""" @@ -177,6 +180,14 @@ class BertPipelineForwards: start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -223,11 +234,17 @@ class BertPipelineForwards: all_cross_attentions = all_cross_attentions + \ (layer_outputs[2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None + sequence_output = hidden_states if hidden_states is not None else None if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -268,6 +285,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) @@ -294,6 +312,7 @@ class BertPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -350,6 +369,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -404,7 +424,8 @@ class BertPipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -457,6 +478,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -491,6 +513,7 @@ class BertPipelineForwards: hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): @@ -532,6 +555,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **kwargs, ): # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: @@ -594,7 +618,8 @@ class BertPipelineForwards: return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -636,6 +661,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -666,7 +692,8 @@ class BertPipelineForwards: return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -726,6 +753,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -742,21 +770,20 @@ class BertPipelineForwards: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -799,6 +826,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -843,6 +871,7 @@ class BertPipelineForwards: hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -886,6 +915,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # NOTE: the arg start_position and end_position are used only for the last stage r""" @@ -909,21 +939,20 @@ class BertPipelineForwards: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -1101,3 +1130,150 @@ def get_jit_fused_bert_output_forward(): return hidden_states return forward + + +def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + embedding_output = split_forward_gather_backward(embedding_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + sequence_output = gather_forward_split_backward(sequence_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ace9ada39..fe091c658 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.bert import ( BertPipelineForwards, + bert_sequence_parallel_forward_fn, get_bert_flash_attention_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, @@ -47,13 +48,14 @@ class BertPolicy(Policy): from transformers.models.bert.modeling_bert import ( BertEmbeddings, BertLayer, + BertModel, BertOutput, BertSelfAttention, BertSelfOutput, ) policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ "attention.self.all_head_size": @@ -69,14 +71,17 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -85,6 +90,7 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -93,10 +99,12 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -115,6 +123,12 @@ class BertPolicy(Policy): ) ]) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BertModel) + # optimization configuration if self.shard_config.enable_fused_normalization: # Handle bert layer @@ -205,7 +219,13 @@ class BertPolicy(Policy): layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)