|
|
|
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager

class GPT2PipelineForwards:
    """
    This class serves as a micro library for forward function substitution of
    GPT2 models under a pipeline setting.
    """

    @staticmethod
    def gpt2_model_forward(
        self: GPT2Model,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
    ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
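        # Compared to the upstream signature, `stage_manager`, `hidden_states`, and
        # `stage_index` are the pipeline-specific additions: non-first stages receive
        # `hidden_states` from the previous stage, and `stage_index` is used to select
        # which decoder blocks this stage runs.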
        # This function is modified from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
        # Please refer to the original transformers code for more details.
        logger = logging.get_logger(__name__)

        # Preprocess the passed-in arguments.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # ... (code omitted in this excerpt) ...

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
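        # For example, in fp32 a mask row [1, 1, 0] becomes [0.0, 0.0, -3.4e38], which
        # drives the attention weight of the masked position to ~0 after the softmax.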
|
|
|
|
|
|
|
|
|
        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length].
        # ... (code omitted in this excerpt) ...

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
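        # `presents` collects each block's key/value cache as this stage's blocks run
        # (when use_cache is enabled).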
            # ... (code omitted in this excerpt; the following lines run inside the
            # loop over this stage's decoder blocks) ...
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)
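
                # torch.utils.checkpoint.checkpoint does not forward keyword arguments,
                # so the use_cache and output_attentions flags are bound via this closure
                # instead of being passed through the checkpoint call.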
        # ... (code omitted in this excerpt) ...

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return {"hidden_states": hidden_states, "past_key_values": presents}
    @staticmethod
    def gpt2_lmhead_model_forward(
        self: GPT2LMHeadModel,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
    ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r""" |
|
|
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
|
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set |
|
|
|
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` |
|
|
|
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` |
|
|
|
|
|
|
|
|
|
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. |
|
|
|
|
Please refer to original code of transformers for more details. |
|
|
|
|
""" |
|
|
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
|
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set |
|
|
|
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` |
|
|
|
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` |
|
|
|
|
|
|
|
|
|
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. |
|
|
|
|
Please refer to original code of transformers for more details. |
|
|
|
|
""" |
|
|
|
|
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # If this is the first stage and past warmup, go through lm_head first.
        if stage_manager.is_first_stage() and hidden_states is not None:
            lm_logits = self.lm_head(hidden_states)
            return {"logits": lm_logits}

        # Not the first stage, or still in warmup: go through the GPT2 model.
        outputs = GPT2PipelineForwards.gpt2_model_forward(
            self.transformer,
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
        )

        return outputs
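

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). A ShardFormer-style policy swaps
# a pipeline forward in by binding the stage arguments with functools.partial and
# attaching the result as a bound method. `stage_manager` and `stage_index` below
# are placeholders: in practice they come from the surrounding ColossalAI
# pipeline setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial
    from types import MethodType

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    stage_manager: PipelineStageManager = ...  # placeholder: built by the pipeline setup
    stage_index = [0, model.config.n_layer]  # placeholder: this stage runs every block

    pipeline_forward = partial(
        GPT2PipelineForwards.gpt2_lmhead_model_forward,
        stage_manager=stage_manager,
        stage_index=stage_index,
    )
    # Bind so that `model` is passed as `self` on each call.
    model.forward = MethodType(pipeline_forward, model)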