mirror of https://github.com/hpcaitech/ColossalAI
[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)
* support gpt2 seq parallel with pp/dp/tp * fix a bug when waiting for stream done * delete unused gpt2_seq filepull/4455/head
parent
a78daf6180
commit
7c8be77081
|
@ -235,6 +235,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
assert dist.get_world_size() % (
|
assert dist.get_world_size() % (
|
||||||
tp_size * pp_size
|
tp_size * pp_size
|
||||||
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
|
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
|
||||||
|
|
||||||
|
if enable_sequence_parallelism:
|
||||||
|
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
|
||||||
|
|
||||||
# TODO(ver217): support zero
|
# TODO(ver217): support zero
|
||||||
assert zero_stage == 0, 'zero is not support yet'
|
assert zero_stage == 0, 'zero is not support yet'
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
|
|
|
@ -239,6 +239,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||||
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
|
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
|
||||||
|
|
||||||
torch.cuda.current_stream().wait_stream(calculate_stream)
|
torch.cuda.current_stream().wait_stream(calculate_stream)
|
||||||
|
gather_handle.wait()
|
||||||
|
|
||||||
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||||
with torch.cuda.stream(calculate_stream):
|
with torch.cuda.stream(calculate_stream):
|
||||||
|
@ -249,6 +250,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||||
grad_weight = grad_output.t().matmul(input_parallel)
|
grad_weight = grad_output.t().matmul(input_parallel)
|
||||||
|
|
||||||
torch.cuda.current_stream().wait_stream(calculate_stream)
|
torch.cuda.current_stream().wait_stream(calculate_stream)
|
||||||
|
reducescatter_handle.wait()
|
||||||
|
|
||||||
return output, grad_weight, grad_bias, None, None, None, None
|
return output, grad_weight, grad_bias, None, None, None, None
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,8 @@ from transformers.models.gpt2.modeling_gpt2 import (
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||||
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
|
|
||||||
class GPT2PipelineForwards:
|
class GPT2PipelineForwards:
|
||||||
|
@ -47,7 +49,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
|
||||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||||
# Please refer to original code of transformers for more details.
|
# Please refer to original code of transformers for more details.
|
||||||
|
@ -159,6 +162,13 @@ class GPT2PipelineForwards:
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# split the input tensor along sequence dimension
|
||||||
|
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||||
|
if shard_config.enable_sequence_parallelism:
|
||||||
|
hidden_states = split_forward_gather_backward(hidden_states,
|
||||||
|
dim=1,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group)
|
||||||
|
|
||||||
# Going through held blocks.
|
# Going through held blocks.
|
||||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
for i in range(start_idx, end_idx):
|
for i in range(start_idx, end_idx):
|
||||||
|
@ -212,6 +222,12 @@ class GPT2PipelineForwards:
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||||
|
|
||||||
|
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||||
|
if shard_config.enable_sequence_parallelism:
|
||||||
|
hidden_states = gather_forward_split_backward(hidden_states,
|
||||||
|
dim=1,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group)
|
||||||
|
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage():
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
|
@ -257,7 +273,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
|
@ -285,7 +302,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index)
|
stage_index=stage_index,
|
||||||
|
shard_config=shard_config)
|
||||||
|
|
||||||
# If not at the last stage, return hidden_states as in GPT2Model
|
# If not at the last stage, return hidden_states as in GPT2Model
|
||||||
if not stage_manager.is_last_stage():
|
if not stage_manager.is_last_stage():
|
||||||
|
@ -335,7 +353,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
|
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
|
||||||
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
||||||
|
@ -367,7 +386,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index)
|
stage_index=stage_index,
|
||||||
|
shard_config=shard_config)
|
||||||
|
|
||||||
# If not at the last stage, return hidden_states as in GPT2Model
|
# If not at the last stage, return hidden_states as in GPT2Model
|
||||||
if not stage_manager.is_last_stage():
|
if not stage_manager.is_last_stage():
|
||||||
|
@ -421,7 +441,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
@ -449,7 +470,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index)
|
stage_index=stage_index,
|
||||||
|
shard_config=shard_config)
|
||||||
|
|
||||||
# If not at the last stage, return hidden_states as in GPT2Model
|
# If not at the last stage, return hidden_states as in GPT2Model
|
||||||
if not stage_manager.is_last_stage():
|
if not stage_manager.is_last_stage():
|
||||||
|
@ -508,7 +530,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
@ -534,7 +557,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index)
|
stage_index=stage_index,
|
||||||
|
shard_config=shard_config)
|
||||||
|
|
||||||
# If not at the last stage, return hidden_states as in GPT2Model
|
# If not at the last stage, return hidden_states as in GPT2Model
|
||||||
if not stage_manager.is_last_stage():
|
if not stage_manager.is_last_stage():
|
||||||
|
@ -578,7 +602,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
@ -613,7 +638,8 @@ class GPT2PipelineForwards:
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
stage_manager=stage_manager,
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index)
|
stage_index=stage_index,
|
||||||
|
shard_config=shard_config)
|
||||||
|
|
||||||
# If not at the last stage, return hidden_states as in GPT2Model
|
# If not at the last stage, return hidden_states as in GPT2Model
|
||||||
if not stage_manager.is_last_stage():
|
if not stage_manager.is_last_stage():
|
||||||
|
@ -696,7 +722,6 @@ def get_gpt2_flash_attention_forward():
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||||
_, tgt_len, _ = hidden_states.size()
|
_, tgt_len, _ = hidden_states.size()
|
||||||
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
|
||||||
|
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
if not hasattr(self, "q_attn"):
|
if not hasattr(self, "q_attn"):
|
||||||
|
@ -753,3 +778,210 @@ def get_gpt2_flash_attention_forward():
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_length = 0
|
||||||
|
past_key_values = tuple([None] * len(self.h))
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
# GPT2Attention mask.
|
||||||
|
if attention_mask is not None:
|
||||||
|
if batch_size <= 0:
|
||||||
|
raise ValueError("batch_size has to be defined and > 0")
|
||||||
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
position_embeds = self.wpe(position_ids)
|
||||||
|
hidden_states = inputs_embeds + position_embeds
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|
||||||
|
hidden_states = self.drop(hidden_states)
|
||||||
|
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
presents = () if use_cache else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# split the input tensor along sequence dimension
|
||||||
|
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||||
|
hidden_states = split_forward_gather_backward(hidden_states,
|
||||||
|
dim=1,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group)
|
||||||
|
|
||||||
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
# Model parallel
|
||||||
|
if self.model_parallel:
|
||||||
|
torch.cuda.set_device(hidden_states.device)
|
||||||
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||||
|
if layer_past is not None:
|
||||||
|
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||||
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
|
if isinstance(head_mask, torch.Tensor):
|
||||||
|
head_mask = head_mask.to(hidden_states.device)
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, use_cache, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
None,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[i],
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = block(
|
||||||
|
hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if use_cache is True:
|
||||||
|
presents = presents + (outputs[1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||||
|
|
||||||
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||||
|
if self.model_parallel:
|
||||||
|
for k, v in self.device_map.items():
|
||||||
|
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||||
|
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||||
|
|
||||||
|
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||||
|
hidden_states = gather_forward_split_backward(hidden_states,
|
||||||
|
dim=1,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group)
|
||||||
|
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
# Add last hidden state
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||||
|
if v is not None)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=presents,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
|
@ -1,222 +0,0 @@
|
||||||
# this code is modified from transformers.models.gpt2.modeling_gpt2
|
|
||||||
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: put all contents in `gpt2.py` and make it compatible with pipeline
|
|
||||||
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (output_hidden_states
|
|
||||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
|
||||||
if position_ids is not None:
|
|
||||||
position_ids = position_ids.view(-1, input_shape[-1])
|
|
||||||
|
|
||||||
if past_key_values is None:
|
|
||||||
past_length = 0
|
|
||||||
past_key_values = tuple([None] * len(self.h))
|
|
||||||
else:
|
|
||||||
past_length = past_key_values[0][0].size(-2)
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
|
||||||
|
|
||||||
# GPT2Attention mask.
|
|
||||||
if attention_mask is not None:
|
|
||||||
if batch_size <= 0:
|
|
||||||
raise ValueError("batch_size has to be defined and > 0")
|
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
attention_mask = attention_mask[:, None, None, :]
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
|
||||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
|
||||||
|
|
||||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
|
||||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
||||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
|
||||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
|
||||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
|
||||||
if encoder_attention_mask is None:
|
|
||||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
|
||||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
|
||||||
else:
|
|
||||||
encoder_attention_mask = None
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
|
||||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.wte(input_ids)
|
|
||||||
position_embeds = self.wpe(position_ids)
|
|
||||||
hidden_states = inputs_embeds + position_embeds
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_embeds = self.wte(token_type_ids)
|
|
||||||
hidden_states = hidden_states + token_type_embeds
|
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
|
||||||
|
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
|
|
||||||
# split the input tensor along sequence dimension
|
|
||||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
|
||||||
hidden_states = split_forward_gather_backward(hidden_states,
|
|
||||||
dim=1,
|
|
||||||
process_group=shard_config.tensor_parallel_process_group)
|
|
||||||
|
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
|
||||||
# Model parallel
|
|
||||||
if self.model_parallel:
|
|
||||||
torch.cuda.set_device(hidden_states.device)
|
|
||||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
|
||||||
if layer_past is not None:
|
|
||||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
|
||||||
# Ensure that attention_mask is always on the same device as hidden_states
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask.to(hidden_states.device)
|
|
||||||
if isinstance(head_mask, torch.Tensor):
|
|
||||||
head_mask = head_mask.to(hidden_states.device)
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
# None for past_key_value
|
|
||||||
return module(*inputs, use_cache, output_attentions)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
outputs = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
hidden_states,
|
|
||||||
None,
|
|
||||||
attention_mask,
|
|
||||||
head_mask[i],
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
outputs = block(
|
|
||||||
hidden_states,
|
|
||||||
layer_past=layer_past,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
head_mask=head_mask[i],
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
if use_cache is True:
|
|
||||||
presents = presents + (outputs[1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
|
||||||
if self.config.add_cross_attention:
|
|
||||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
|
||||||
|
|
||||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
|
||||||
if self.model_parallel:
|
|
||||||
for k, v in self.device_map.items():
|
|
||||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
|
||||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
|
||||||
|
|
||||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
|
||||||
hidden_states = gather_forward_split_backward(hidden_states,
|
|
||||||
dim=1,
|
|
||||||
process_group=shard_config.tensor_parallel_process_group)
|
|
||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
|
||||||
hidden_states = hidden_states.view(output_shape)
|
|
||||||
# Add last hidden state
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(
|
|
||||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
|
||||||
if v is not None)
|
|
||||||
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=presents,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attentions,
|
|
||||||
cross_attentions=all_cross_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
return forward
|
|
|
@ -6,8 +6,7 @@ from torch import Tensor, nn
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
|
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
|
||||||
from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
|
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -50,8 +49,6 @@ class GPT2Policy(Policy):
|
||||||
target_module=col_nn.DropoutForParallelInput,
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
if self.shard_config.enable_sequence_parallelism:
|
|
||||||
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
|
|
||||||
|
|
||||||
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
|
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
|
||||||
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
|
@ -126,6 +123,7 @@ class GPT2Policy(Policy):
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.shard_config.enable_sequence_parallelism:
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
|
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
|
||||||
suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
|
suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
|
||||||
self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
|
self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
|
||||||
|
|
||||||
|
@ -169,7 +167,13 @@ class GPT2Policy(Policy):
|
||||||
|
|
||||||
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
method_replacement = {
|
||||||
|
'forward':
|
||||||
|
partial(new_forward,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
stage_index=stage_index,
|
||||||
|
shard_config=self.shard_config)
|
||||||
|
}
|
||||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -105,10 +105,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
'enable_all_optimization': True,
|
'enable_all_optimization': True,
|
||||||
'use_lazy_init': False,
|
'use_lazy_init': False,
|
||||||
'precision': 'fp32',
|
'precision': 'fp32',
|
||||||
|
}, {
|
||||||
|
'tp_size': 2,
|
||||||
|
'pp_size': 2,
|
||||||
|
'num_microbatches': 4,
|
||||||
|
'enable_all_optimization': True,
|
||||||
|
'use_lazy_init': True,
|
||||||
|
'enable_sequence_parallelism': True,
|
||||||
|
'precision': 'fp32',
|
||||||
}, {
|
}, {
|
||||||
'tp_size': 4,
|
'tp_size': 4,
|
||||||
'pp_size': 1,
|
'pp_size': 1,
|
||||||
'enable_all_optimization': False,
|
'enable_all_optimization': True,
|
||||||
'use_lazy_init': True,
|
'use_lazy_init': True,
|
||||||
'enable_sequence_parallelism': True,
|
'enable_sequence_parallelism': True,
|
||||||
'precision': 'fp32',
|
'precision': 'fp32',
|
||||||
|
|
Loading…
Reference in New Issue