[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)

* support gpt2 seq parallel with pp/dp/tp

* fix a bug when waiting for stream done

* delete unused gpt2_seq file
Bin Jia 2023-08-18 11:21:53 +08:00 committed by GitHub
parent a78daf6180
commit 7c8be77081
6 changed files with 268 additions and 240 deletions

View File

@@ -235,6 +235,10 @@ class HybridParallelPlugin(PipelinePluginBase):
assert dist.get_world_size() % (
tp_size * pp_size
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
if enable_sequence_parallelism:
assert tp_size > 1, 'Tensor parallelism must be enabled (tp_size > 1) when using sequence parallelism'
# TODO(ver217): support zero
assert zero_stage == 0, 'zero is not supported yet'
self.tp_size = tp_size
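For orientation, here is a minimal usage sketch of the new constraint: sequence parallelism splits activations along the sequence dimension across the tensor-parallel group, so it only makes sense with tp_size > 1 (and, at this point, zero_stage == 0). The plugin arguments mirror the test configurations added at the end of this diff; the booster setup around them is an assumption for illustration, not part of this change.

# Hypothetical sketch: enabling GPT2 sequence parallelism together with tp/pp/dp.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                         # tensor-parallel group size; must be > 1 for sequence parallelism
    pp_size=2,                         # pipeline-parallel stages
    num_microbatches=4,                # microbatches for the pipeline schedule
    enable_sequence_parallelism=True,  # split activations along the sequence dimension
    precision='fp32',
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)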

View File

@@ -239,6 +239,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
torch.cuda.current_stream().wait_stream(calculate_stream)
gather_handle.wait()
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
with torch.cuda.stream(calculate_stream):
@@ -249,6 +250,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
grad_weight = grad_output.t().matmul(input_parallel)
torch.cuda.current_stream().wait_stream(calculate_stream)
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
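The two wait() calls added in this hunk are the "waiting for stream done" fix from the commit message: an async_op=True collective returns immediately with a work handle, so besides synchronizing the side stream, the handle itself must be waited on before the result is used or returned. Below is a stand-alone sketch of the pattern; it is illustrative only, with made-up tensor shapes, not the shardformer implementation.

# Sketch: overlap an async reduce-scatter with independent computation on a side
# stream, then synchronize both the stream and the collective handle.
import torch
import torch.distributed as dist

def overlapped_reduce_scatter(input_list, input_parallel, grad_output, process_group):
    calculate_stream = torch.cuda.Stream()
    output = torch.empty_like(input_list[0])

    # Launch communication asynchronously; this returns a work handle immediately.
    handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

    # Do unrelated math on a separate CUDA stream while communication is in flight.
    with torch.cuda.stream(calculate_stream):
        grad_weight = grad_output.t().matmul(input_parallel)

    # Synchronize: make the default stream wait for the side stream, and wait on the
    # handle so `output` is guaranteed to hold the reduce-scatter result.
    torch.cuda.current_stream().wait_stream(calculate_stream)
    handle.wait()
    return output, grad_weight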

View File

@@ -21,6 +21,8 @@ from transformers.models.gpt2.modeling_gpt2 import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
class GPT2PipelineForwards:
@@ -47,7 +49,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
@@ -159,6 +162,13 @@ class GPT2PipelineForwards:
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along the sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
@@ -212,6 +222,12 @@ class GPT2PipelineForwards:
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# When sequence parallelism is enabled, gather the output tensor along the sequence dimension in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
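The split_forward_gather_backward / gather_forward_split_backward pair used above brackets the transformer blocks so that each tensor-parallel rank only processes seq_len / tp_size tokens' worth of activations. Below is a rough conceptual sketch of the splitting half, assuming it behaves like a chunk-in-forward / all-gather-in-backward autograd function; this is an illustration only, not the implementation in colossalai.shardformer.layer._operation.

# Conceptual sketch of split-forward / gather-backward along a given dimension.
import torch
import torch.distributed as dist

class _SplitForwardGatherBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        # Forward: keep only this rank's chunk of the chosen dimension.
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)
        return input_.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: all-gather the gradient chunks so the full-sequence gradient
        # flows back to whatever produced the unsplit input.
        world_size = dist.get_world_size(ctx.process_group)
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.process_group)
        return torch.cat(chunks, dim=ctx.dim), None, None

The gather counterpart does the reverse (all-gather in forward, keep this rank's chunk in backward), which is why the split and gather calls must appear as a matched pair around the sequence-parallel region.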
@@ -257,7 +273,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -285,7 +302,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -335,7 +353,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
@@ -367,7 +386,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -421,7 +441,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -449,7 +470,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -508,7 +530,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -534,7 +557,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -578,7 +602,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -613,7 +638,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -696,7 +722,6 @@ def get_gpt2_flash_attention_forward():
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
_, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@@ -753,3 +778,210 @@ def get_gpt2_flash_attention_forward():
return outputs
return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along the sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When the sequence-parallel blocks are done, gather the output tensor along the sequence dimension in forward and split it in backward
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward

View File

@@ -1,222 +0,0 @@
# this code is modified from transformers.models.gpt2.modeling_gpt2
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import logging
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
# TODO: put all contents in `gpt2.py` and make it compatible with pipeline
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
# (the remainder of the deleted gpt2_seq file is identical to the gpt2_sequence_parallel_forward_fn added to modeling/gpt2.py above)

View File

@@ -6,8 +6,7 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -50,8 +49,6 @@ class GPT2Policy(Policy):
target_module=col_nn.DropoutForParallelInput,
),
])
if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -126,6 +123,7 @@ class GPT2Policy(Policy):
})
if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
@@ -169,7 +167,8 @@ class GPT2Policy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

View File

@@ -105,10 +105,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': False,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',