|
|
|
@ -21,8 +21,9 @@ from transformers.models.gpt2.modeling_gpt2 import (
|
|
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
from colossalai.shardformer.layer import ColoAttention
|
|
|
|
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
|
|
|
|
from colossalai.shardformer.layer import ColoAttention, RingAttention
|
|
|
|
|
from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
|
|
|
|
|
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
|
|
|
|
from colossalai.shardformer.shard import ShardConfig
|
|
|
|
|
|
|
|
|
|
from ..layer import dist_cross_entropy
|
|
|
|
@ -39,10 +40,16 @@ def _get_attention_mask(
|
|
|
|
|
encoder_hidden_states: Optional[torch.Tensor],
|
|
|
|
|
encoder_attention_mask: Optional[torch.FloatTensor],
|
|
|
|
|
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
|
|
|
|
|
batch_size, seq_len = hidden_states.shape[:2]
|
|
|
|
|
# Received input is already split for non-first pipeline stages,
|
|
|
|
|
# but attn mask isn't
|
|
|
|
|
batch_size = hidden_states.size(0)
|
|
|
|
|
seq_len = attention_mask.size(-1)
|
|
|
|
|
|
|
|
|
|
sp_mode = shard_config.sequence_parallelism_mode
|
|
|
|
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
|
|
|
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
|
|
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
|
|
|
|
assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
|
|
|
|
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
|
|
|
|
if shard_config.enable_flash_attention:
|
|
|
|
|
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
|
|
|
|
@ -62,6 +69,7 @@ def _get_attention_mask(
|
|
|
|
|
encoder_attention_mask = {"attention_mask": None}
|
|
|
|
|
else:
|
|
|
|
|
encoder_attention_mask = None
|
|
|
|
|
|
|
|
|
|
# GPT2Attention mask.
|
|
|
|
|
past_key_values_length = 0
|
|
|
|
|
if past_key_values is not None and past_key_values[0] is not None:
|
|
|
|
@ -69,6 +77,7 @@ def _get_attention_mask(
|
|
|
|
|
if shard_config.enable_flash_attention:
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
attention_mask = attention_mask.view(batch_size, -1)
|
|
|
|
|
|
|
|
|
|
attention_mask = ColoAttention.prepare_attn_kwargs(
|
|
|
|
|
(batch_size, 1, seq_len, seq_len + past_key_values_length),
|
|
|
|
|
hidden_states.dtype,
|
|
|
|
@ -123,6 +132,7 @@ class GPT2PipelineForwards:
|
|
|
|
|
hidden_states: Optional[torch.FloatTensor] = None,
|
|
|
|
|
stage_index: Optional[List[int]] = None,
|
|
|
|
|
shard_config: ShardConfig = None,
|
|
|
|
|
force_sp_gather: Optional[bool] = True,
|
|
|
|
|
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
|
|
|
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
|
|
|
|
# Please refer to original code of transformers for more details.
|
|
|
|
@ -146,16 +156,15 @@ class GPT2PipelineForwards:
|
|
|
|
|
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
|
|
|
|
use_cache = False
|
|
|
|
|
|
|
|
|
|
if stage_manager.is_first_stage():
|
|
|
|
|
disable_pp = stage_manager is None
|
|
|
|
|
if disable_pp or stage_manager.is_first_stage():
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None:
|
|
|
|
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
|
|
elif input_ids is not None:
|
|
|
|
|
input_shape = input_ids.size()
|
|
|
|
|
input_ids = input_ids.view(-1, input_shape[-1])
|
|
|
|
|
input_ids.shape[0]
|
|
|
|
|
elif inputs_embeds is not None:
|
|
|
|
|
input_shape = inputs_embeds.size()[:-1]
|
|
|
|
|
inputs_embeds.shape[0]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
|
|
|
|
|
|
|
@ -176,7 +185,7 @@ class GPT2PipelineForwards:
|
|
|
|
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
|
|
|
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
|
|
|
|
|
|
|
|
|
if stage_manager.is_first_stage():
|
|
|
|
|
if disable_pp or stage_manager.is_first_stage():
|
|
|
|
|
if position_ids is None:
|
|
|
|
|
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
|
|
|
|
|
position_ids = position_ids.unsqueeze(0)
|
|
|
|
@ -190,9 +199,7 @@ class GPT2PipelineForwards:
|
|
|
|
|
hidden_states = hidden_states + token_type_embeds
|
|
|
|
|
hidden_states = self.drop(hidden_states)
|
|
|
|
|
|
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
|
|
|
|
|
attention_mask, encoder_attention_mask = _get_attention_mask(
|
|
|
|
|
attn_kwargs, encoder_attention_mask = _get_attention_mask(
|
|
|
|
|
self,
|
|
|
|
|
shard_config,
|
|
|
|
|
hidden_states,
|
|
|
|
@ -215,23 +222,43 @@ class GPT2PipelineForwards:
|
|
|
|
|
|
|
|
|
|
# split the input tensor along sequence dimension
|
|
|
|
|
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
|
|
|
|
if shard_config and shard_config.enable_sequence_parallelism:
|
|
|
|
|
if shard_config.sequence_parallelism_mode == "split_gather":
|
|
|
|
|
sp_mode = shard_config.sequence_parallelism_mode
|
|
|
|
|
sp_group = shard_config.sequence_parallel_process_group
|
|
|
|
|
if disable_pp or stage_manager.is_first_stage():
|
|
|
|
|
# Ring Attention's special zigzag batch processing
|
|
|
|
|
if sp_mode == "ring_attn":
|
|
|
|
|
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
|
|
|
|
if not attention_mask.bool().all():
|
|
|
|
|
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
|
|
|
|
attention_mask, sp_group, hidden_states, position_ids
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
|
|
|
|
|
# Other sp modes
|
|
|
|
|
else:
|
|
|
|
|
if sp_mode == "split_gather":
|
|
|
|
|
hidden_states = split_forward_gather_backward(
|
|
|
|
|
hidden_states,
|
|
|
|
|
dim=1,
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group,
|
|
|
|
|
fp8_communication=shard_config.fp8_communication,
|
|
|
|
|
)
|
|
|
|
|
elif sp_mode == "ring_attn":
|
|
|
|
|
# Later stages already received split hidden states
|
|
|
|
|
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
|
|
|
|
|
del attention_mask
|
|
|
|
|
|
|
|
|
|
# Going through held blocks.
|
|
|
|
|
if disable_pp:
|
|
|
|
|
start_idx, end_idx = 0, len(self.h)
|
|
|
|
|
else:
|
|
|
|
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
|
|
|
|
|
|
|
|
|
for i in range(start_idx, end_idx):
|
|
|
|
|
block = self.h[i]
|
|
|
|
|
torch.cuda.set_device(hidden_states.device)
|
|
|
|
|
# Ensure that attention_mask is always on the same device as hidden_states
|
|
|
|
|
if torch.is_tensor(attention_mask):
|
|
|
|
|
attention_mask = attention_mask.to(hidden_states.device)
|
|
|
|
|
if torch.is_tensor(attn_kwargs):
|
|
|
|
|
attn_kwargs = attn_kwargs.to(hidden_states.device)
|
|
|
|
|
if isinstance(head_mask, torch.Tensor):
|
|
|
|
|
head_mask = head_mask.to(hidden_states.device)
|
|
|
|
|
if output_hidden_states:
|
|
|
|
@ -242,7 +269,7 @@ class GPT2PipelineForwards:
|
|
|
|
|
block.__call__,
|
|
|
|
|
hidden_states,
|
|
|
|
|
None,
|
|
|
|
|
attention_mask,
|
|
|
|
|
attn_kwargs,
|
|
|
|
|
head_mask[i],
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask,
|
|
|
|
@ -253,7 +280,7 @@ class GPT2PipelineForwards:
|
|
|
|
|
outputs = block(
|
|
|
|
|
hidden_states,
|
|
|
|
|
layer_past=None,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
attention_mask=attn_kwargs,
|
|
|
|
|
head_mask=head_mask[i],
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
|
|
@ -270,26 +297,25 @@ class GPT2PipelineForwards:
|
|
|
|
|
if self.config.add_cross_attention:
|
|
|
|
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
|
|
|
|
|
|
|
|
|
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
|
|
|
|
if shard_config and shard_config.enable_sequence_parallelism:
|
|
|
|
|
if shard_config.sequence_parallelism_mode == "split_gather":
|
|
|
|
|
hidden_states = gather_forward_split_backward(
|
|
|
|
|
hidden_states,
|
|
|
|
|
dim=1,
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group,
|
|
|
|
|
fp8_communication=shard_config.fp8_communication,
|
|
|
|
|
)
|
|
|
|
|
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
|
|
|
|
|
gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
|
|
|
|
|
if disable_pp or stage_manager.is_last_stage():
|
|
|
|
|
if gather_output:
|
|
|
|
|
hidden_states = gather_sp_output(hidden_states, shard_config)
|
|
|
|
|
|
|
|
|
|
if stage_manager.is_last_stage():
|
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
|
# gather_sp_output could've changed seq length.
|
|
|
|
|
input_shape = (*input_shape[:-1], hidden_states.size(-2))
|
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
|
|
|
|
|
if disable_pp or stage_manager.is_last_stage():
|
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
|
hidden_states = hidden_states.view(output_shape)
|
|
|
|
|
|
|
|
|
|
# Add last hidden state
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
if stage_manager.is_last_stage():
|
|
|
|
|
if disable_pp or stage_manager.is_last_stage():
|
|
|
|
|
if not return_dict:
|
|
|
|
|
return tuple(
|
|
|
|
|
v
|
|
|
|
@ -366,16 +392,28 @@ class GPT2PipelineForwards:
|
|
|
|
|
hidden_states=hidden_states,
|
|
|
|
|
stage_index=stage_index,
|
|
|
|
|
shard_config=shard_config,
|
|
|
|
|
force_sp_gather=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# If not at the last stage, return hidden_states as in GPT2Model
|
|
|
|
|
if not stage_manager.is_last_stage():
|
|
|
|
|
disable_pp = stage_manager is None
|
|
|
|
|
if (not disable_pp) and (not stage_manager.is_last_stage()):
|
|
|
|
|
return {"hidden_states": outputs["hidden_states"]}
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
lm_logits = self.lm_head(hidden_states)
|
|
|
|
|
if shard_config.sequence_parallelism_mode == "ring_attn":
|
|
|
|
|
# Split labels in a zigzag fashion too
|
|
|
|
|
sp_group = shard_config.sequence_parallel_process_group
|
|
|
|
|
if not attention_mask.bool().all():
|
|
|
|
|
# [B, max_seqlen // sp_size]
|
|
|
|
|
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
|
|
|
|
else:
|
|
|
|
|
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
|
|
|
|
|
|
|
|
|
if labels is not None:
|
|
|
|
|
loss = dist_cross_entropy(
|
|
|
|
|
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
|
|
|
|
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
@ -770,7 +808,7 @@ class GPT2PipelineForwards:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gpt2_flash_attention_forward():
|
|
|
|
|
def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None):
|
|
|
|
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
@ -817,6 +855,21 @@ def get_gpt2_flash_attention_forward():
|
|
|
|
|
if self.scale_attn_by_inverse_layer_idx:
|
|
|
|
|
scale /= float(self.layer_idx + 1)
|
|
|
|
|
dropout_p = self.attn_dropout.p if self.training else 0.0
|
|
|
|
|
|
|
|
|
|
sp_mode = shard_config.sequence_parallelism_mode
|
|
|
|
|
sp_group = shard_config.sequence_parallel_process_group
|
|
|
|
|
if sp_mode == "ring_attn":
|
|
|
|
|
attn_output = RingAttention.attention(
|
|
|
|
|
query,
|
|
|
|
|
key,
|
|
|
|
|
value,
|
|
|
|
|
sp_group,
|
|
|
|
|
**attention_mask,
|
|
|
|
|
dropout_p=dropout_p,
|
|
|
|
|
scale=scale,
|
|
|
|
|
inner_ring_size=shard_config.inner_ring_size,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
|
|
|
|
|
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
|
|
|
|
attn_output = self.c_proj(attn_output)
|
|
|
|
@ -828,466 +881,6 @@ def get_gpt2_flash_attention_forward():
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
|
|
|
|
|
def forward(
|
|
|
|
|
self: GPT2Model,
|
|
|
|
|
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
|
|
|
attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
token_type_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
head_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
|
|
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
output_hidden_states = (
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None:
|
|
|
|
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
|
|
elif input_ids is not None:
|
|
|
|
|
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
|
|
|
|
input_shape = input_ids.size()
|
|
|
|
|
input_ids = input_ids.view(-1, input_shape[-1])
|
|
|
|
|
input_ids.shape[0]
|
|
|
|
|
elif inputs_embeds is not None:
|
|
|
|
|
input_shape = inputs_embeds.size()[:-1]
|
|
|
|
|
inputs_embeds.shape[0]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
|
|
|
|
|
|
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
|
|
|
|
|
|
|
|
if token_type_ids is not None:
|
|
|
|
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
|
|
|
|
if position_ids is not None:
|
|
|
|
|
position_ids = position_ids.view(-1, input_shape[-1])
|
|
|
|
|
|
|
|
|
|
if past_key_values is None:
|
|
|
|
|
past_length = 0
|
|
|
|
|
past_key_values = tuple([None] * len(self.h))
|
|
|
|
|
else:
|
|
|
|
|
past_length = past_key_values[0][0].size(-2)
|
|
|
|
|
if position_ids is None:
|
|
|
|
|
position_ids = torch.arange(
|
|
|
|
|
past_length,
|
|
|
|
|
input_shape[-1] + past_length,
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=device,
|
|
|
|
|
)
|
|
|
|
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
|
|
|
|
|
|
|
|
|
# Prepare head mask if needed
|
|
|
|
|
# 1.0 in head_mask indicate we keep the head
|
|
|
|
|
# attention_probs has shape bsz x n_heads x N x N
|
|
|
|
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
|
|
|
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
|
|
inputs_embeds = self.wte(input_ids)
|
|
|
|
|
position_embeds = self.wpe(position_ids)
|
|
|
|
|
hidden_states = inputs_embeds + position_embeds
|
|
|
|
|
|
|
|
|
|
if token_type_ids is not None:
|
|
|
|
|
token_type_embeds = self.wte(token_type_ids)
|
|
|
|
|
hidden_states = hidden_states + token_type_embeds
|
|
|
|
|
|
|
|
|
|
hidden_states = self.drop(hidden_states)
|
|
|
|
|
|
|
|
|
|
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
|
|
|
|
|
|
|
|
|
attention_mask, encoder_attention_mask = _get_attention_mask(
|
|
|
|
|
self,
|
|
|
|
|
shard_config,
|
|
|
|
|
hidden_states,
|
|
|
|
|
past_key_values,
|
|
|
|
|
attention_mask,
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
if use_cache:
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
|
|
|
)
|
|
|
|
|
use_cache = False
|
|
|
|
|
|
|
|
|
|
presents = () if use_cache else None
|
|
|
|
|
all_self_attentions = () if output_attentions else None
|
|
|
|
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None
|
|
|
|
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
|
|
|
|
# Model parallel
|
|
|
|
|
if self.model_parallel:
|
|
|
|
|
torch.cuda.set_device(hidden_states.device)
|
|
|
|
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
|
|
|
|
if layer_past is not None:
|
|
|
|
|
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
|
|
|
|
# Ensure that attention_mask is always on the same device as hidden_states
|
|
|
|
|
if torch.is_tensor(attention_mask):
|
|
|
|
|
attention_mask = attention_mask.to(hidden_states.device)
|
|
|
|
|
if isinstance(head_mask, torch.Tensor):
|
|
|
|
|
head_mask = head_mask.to(hidden_states.device)
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
|
|
|
|
|
def create_custom_forward(module):
|
|
|
|
|
def custom_forward(*inputs):
|
|
|
|
|
# None for past_key_value
|
|
|
|
|
return module(*inputs, use_cache, output_attentions)
|
|
|
|
|
|
|
|
|
|
return custom_forward
|
|
|
|
|
|
|
|
|
|
outputs = torch.utils.checkpoint.checkpoint(
|
|
|
|
|
create_custom_forward(block),
|
|
|
|
|
hidden_states,
|
|
|
|
|
None,
|
|
|
|
|
attention_mask,
|
|
|
|
|
head_mask[i],
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
outputs = block(
|
|
|
|
|
hidden_states,
|
|
|
|
|
layer_past=layer_past,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
head_mask=head_mask[i],
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
|
|
|
use_cache=use_cache,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
if use_cache is True:
|
|
|
|
|
presents = presents + (outputs[1],)
|
|
|
|
|
|
|
|
|
|
if output_attentions:
|
|
|
|
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
|
|
|
|
if self.config.add_cross_attention:
|
|
|
|
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
|
|
|
|
|
|
|
|
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
|
|
|
|
if self.model_parallel:
|
|
|
|
|
for k, v in self.device_map.items():
|
|
|
|
|
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
|
|
|
|
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
|
|
|
|
|
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
|
|
|
|
|
|
hidden_states = hidden_states.view(output_shape)
|
|
|
|
|
# Add last hidden state
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
return tuple(
|
|
|
|
|
v
|
|
|
|
|
for v in [
|
|
|
|
|
hidden_states,
|
|
|
|
|
presents,
|
|
|
|
|
all_hidden_states,
|
|
|
|
|
all_self_attentions,
|
|
|
|
|
all_cross_attentions,
|
|
|
|
|
]
|
|
|
|
|
if v is not None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return BaseModelOutputWithPastAndCrossAttentions(
|
|
|
|
|
last_hidden_state=hidden_states,
|
|
|
|
|
past_key_values=presents,
|
|
|
|
|
hidden_states=all_hidden_states,
|
|
|
|
|
attentions=all_self_attentions,
|
|
|
|
|
cross_attentions=all_cross_attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
|
|
|
attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
token_type_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
head_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
|
|
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
output_hidden_states = (
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None:
|
|
|
|
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
|
|
elif input_ids is not None:
|
|
|
|
|
input_shape = input_ids.size()
|
|
|
|
|
input_ids = input_ids.view(-1, input_shape[-1])
|
|
|
|
|
input_ids.shape[0]
|
|
|
|
|
elif inputs_embeds is not None:
|
|
|
|
|
input_shape = inputs_embeds.size()[:-1]
|
|
|
|
|
inputs_embeds.shape[0]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
|
|
|
|
|
|
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
|
|
|
|
|
|
|
|
if token_type_ids is not None:
|
|
|
|
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
|
|
|
|
if position_ids is not None:
|
|
|
|
|
position_ids = position_ids.view(-1, input_shape[-1])
|
|
|
|
|
|
|
|
|
|
if past_key_values is None:
|
|
|
|
|
past_length = 0
|
|
|
|
|
past_key_values = tuple([None] * len(self.h))
|
|
|
|
|
else:
|
|
|
|
|
past_length = past_key_values[0][0].size(-2)
|
|
|
|
|
if position_ids is None:
|
|
|
|
|
position_ids = torch.arange(
|
|
|
|
|
past_length,
|
|
|
|
|
input_shape[-1] + past_length,
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=device,
|
|
|
|
|
)
|
|
|
|
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
|
|
|
|
|
|
|
|
|
# Prepare head mask if needed
|
|
|
|
|
# 1.0 in head_mask indicate we keep the head
|
|
|
|
|
# attention_probs has shape bsz x n_heads x N x N
|
|
|
|
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
|
|
|
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
|
|
inputs_embeds = self.wte(input_ids)
|
|
|
|
|
position_embeds = self.wpe(position_ids)
|
|
|
|
|
hidden_states = inputs_embeds + position_embeds
|
|
|
|
|
|
|
|
|
|
if token_type_ids is not None:
|
|
|
|
|
token_type_embeds = self.wte(token_type_ids)
|
|
|
|
|
hidden_states = hidden_states + token_type_embeds
|
|
|
|
|
|
|
|
|
|
hidden_states = self.drop(hidden_states)
|
|
|
|
|
|
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
attention_mask, encoder_attention_mask = _get_attention_mask(
|
|
|
|
|
self,
|
|
|
|
|
shard_config,
|
|
|
|
|
hidden_states,
|
|
|
|
|
past_key_values,
|
|
|
|
|
attention_mask,
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
if use_cache:
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
|
|
|
)
|
|
|
|
|
use_cache = False
|
|
|
|
|
|
|
|
|
|
presents = () if use_cache else None
|
|
|
|
|
all_self_attentions = () if output_attentions else None
|
|
|
|
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None
|
|
|
|
|
|
|
|
|
|
# split the input tensor along sequence dimension
|
|
|
|
|
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
|
|
|
|
hidden_states = split_forward_gather_backward(
|
|
|
|
|
hidden_states,
|
|
|
|
|
dim=1,
|
|
|
|
|
process_group=shard_config.sequence_parallel_process_group,
|
|
|
|
|
fp8_communication=shard_config.fp8_communication,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
|
|
|
|
# Model parallel
|
|
|
|
|
if self.model_parallel:
|
|
|
|
|
torch.cuda.set_device(hidden_states.device)
|
|
|
|
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
|
|
|
|
if layer_past is not None:
|
|
|
|
|
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
|
|
|
|
# Ensure that attention_mask is always on the same device as hidden_states
|
|
|
|
|
if torch.is_tensor(attention_mask):
|
|
|
|
|
attention_mask = attention_mask.to(hidden_states.device)
|
|
|
|
|
if isinstance(head_mask, torch.Tensor):
|
|
|
|
|
head_mask = head_mask.to(hidden_states.device)
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
|
|
|
|
|
def create_custom_forward(module):
|
|
|
|
|
def custom_forward(*inputs):
|
|
|
|
|
# None for past_key_value
|
|
|
|
|
return module(*inputs, use_cache, output_attentions)
|
|
|
|
|
|
|
|
|
|
return custom_forward
|
|
|
|
|
|
|
|
|
|
outputs = torch.utils.checkpoint.checkpoint(
|
|
|
|
|
create_custom_forward(block),
|
|
|
|
|
hidden_states,
|
|
|
|
|
None,
|
|
|
|
|
attention_mask,
|
|
|
|
|
head_mask[i],
|
|
|
|
|
encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
outputs = block(
|
|
|
|
|
hidden_states,
|
|
|
|
|
layer_past=layer_past,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
head_mask=head_mask[i],
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
|
|
|
use_cache=use_cache,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
if use_cache is True:
|
|
|
|
|
presents = presents + (outputs[1],)
|
|
|
|
|
|
|
|
|
|
if output_attentions:
|
|
|
|
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
|
|
|
|
if self.config.add_cross_attention:
|
|
|
|
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
|
|
|
|
|
|
|
|
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
|
|
|
|
if self.model_parallel:
|
|
|
|
|
for k, v in self.device_map.items():
|
|
|
|
|
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
|
|
|
|
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
|
|
|
|
|
|
|
|
|
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
|
|
|
|
hidden_states = gather_forward_split_backward(
|
|
|
|
|
hidden_states,
|
|
|
|
|
dim=1,
|
|
|
|
|
process_group=shard_config.sequence_parallel_process_group,
|
|
|
|
|
fp8_communication=shard_config.fp8_communication,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
|
hidden_states = hidden_states.view(output_shape)
|
|
|
|
|
# Add last hidden state
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
return tuple(
|
|
|
|
|
v
|
|
|
|
|
for v in [
|
|
|
|
|
hidden_states,
|
|
|
|
|
presents,
|
|
|
|
|
all_hidden_states,
|
|
|
|
|
all_self_attentions,
|
|
|
|
|
all_cross_attentions,
|
|
|
|
|
]
|
|
|
|
|
if v is not None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return BaseModelOutputWithPastAndCrossAttentions(
|
|
|
|
|
last_hidden_state=hidden_states,
|
|
|
|
|
past_key_values=presents,
|
|
|
|
|
hidden_states=all_hidden_states,
|
|
|
|
|
attentions=all_self_attentions,
|
|
|
|
|
cross_attentions=all_cross_attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|
|
|
|
from transformers import GPT2LMHeadModel
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self: GPT2LMHeadModel,
|
|
|
|
|
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
|
|
|
attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
token_type_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
head_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
|
|
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
labels: Optional[torch.LongTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
|
|
|
|
r"""
|
|
|
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
|
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
|
|
|
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
|
|
|
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
|
|
|
|
"""
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
transformer_outputs = self.transformer(
|
|
|
|
|
input_ids,
|
|
|
|
|
past_key_values=past_key_values,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
token_type_ids=token_type_ids,
|
|
|
|
|
position_ids=position_ids,
|
|
|
|
|
head_mask=head_mask,
|
|
|
|
|
inputs_embeds=inputs_embeds,
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
|
|
|
use_cache=use_cache,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
return_dict=return_dict,
|
|
|
|
|
)
|
|
|
|
|
hidden_states = transformer_outputs[0]
|
|
|
|
|
|
|
|
|
|
lm_logits = self.lm_head(hidden_states)
|
|
|
|
|
loss = dist_cross_entropy(
|
|
|
|
|
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (lm_logits,) + transformer_outputs[1:]
|
|
|
|
|
return ((loss,) + output) if loss is not None else output
|
|
|
|
|
|
|
|
|
|
return CausalLMOutputWithCrossAttentions(
|
|
|
|
|
loss=loss,
|
|
|
|
|
logits=lm_logits,
|
|
|
|
|
past_key_values=transformer_outputs.past_key_values,
|
|
|
|
|
hidden_states=transformer_outputs.hidden_states,
|
|
|
|
|
attentions=transformer_outputs.attentions,
|
|
|
|
|
cross_attentions=transformer_outputs.cross_attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_gpt2_mlp_forward():
|
|
|
|
|
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
|
|
|
|
|
|
|
|
|
|