[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)

* switch warnings in the pipeline forwards to the transformers logger (warning_once)

* add a pipeline forward for GPT2ForQuestionAnswering

* replace the assert on hidden_states with an explicit ValueError

* fix the torchrec test by re-enabling the torchrec import
Authored by Baizhou Zhang, committed by Hongxin Liu
parent d9be0472ef
commit 2a2eacfaf1
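For context (a hedged sketch, not part of the diff): with the policy registered in the first hunk below, a GPT2ForQuestionAnswering model can be handed to ShardFormer like the other supported GPT2 variants. The exact ShardConfig fields and the return value of optimize() are assumptions drawn from this branch and may differ across ColossalAI versions.

# Hedged usage sketch; assumes colossalai.launch(...) has already initialized
# the distributed environment.
from transformers import GPT2Config, GPT2ForQuestionAnswering
from colossalai.shardformer import ShardConfig, ShardFormer

model = GPT2ForQuestionAnswering(GPT2Config(n_layer=4, n_head=4, n_embd=128))

# The feature flags mirror the test arguments further down
# (enable_fused_normalization, enable_tensor_parallelism).
shard_config = ShardConfig(enable_tensor_parallelism=True,
                           enable_fused_normalization=False)
shard_former = ShardFormer(shard_config=shard_config)

# The policy is resolved automatically through _POLICY_LIST; the
# (sharded_model, shared_params) return shape is assumed here.
sharded_model, shared_params = shard_former.optimize(model)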

@@ -68,6 +68,8 @@ _POLICY_LIST = {
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":

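For orientation (an illustration, not the ColossalAI implementation): the table above keys each supported architecture by its fully qualified class name, and the auto-policy machinery resolves a model to its PolicyLocation roughly as sketched below; lookup_policy is a hypothetical name.

from dataclasses import dataclass

@dataclass
class PolicyLocation:
    file_name: str    # module under the shardformer policies package, e.g. "gpt2"
    class_name: str   # policy class defined inside that module

_POLICY_LIST = {
    "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
        PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
}

def lookup_policy(model) -> PolicyLocation:
    # The registry key is the fully qualified class name of the model instance.
    cls = type(model)
    return _POLICY_LIST[f"{cls.__module__}.{cls.__qualname__}"]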
@@ -1,6 +1,4 @@
import logging
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -298,6 +296,33 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
return self.model
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
module_policy = super().module_policy()
self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''No shared_params in gpt2 for QA.'''
return []
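The get_held_layers override above keeps the transformer layers assigned by the base GPT2Policy and adds the qa_outputs head only on the last pipeline stage. A standalone sketch of that partitioning idea (illustrative only; held_layers_for_stage is a hypothetical helper, not ColossalAI code):

import torch.nn as nn

def held_layers_for_stage(blocks: nn.ModuleList, head: nn.Module,
                          stage: int, num_stages: int) -> list:
    # Slice the transformer blocks evenly across stages (any remainder goes to
    # the earlier stages), then attach the task head on the last stage only.
    per_stage = [len(blocks) // num_stages + (1 if i < len(blocks) % num_stages else 0)
                 for i in range(num_stages)]
    start = sum(per_stage[:stage])
    held = list(blocks[start:start + per_stage[stage]])
    if stage == num_stages - 1:
        held.append(head)    # e.g. the qa_outputs Linear of GPT2ForQuestionAnswering
    return held

# 12 GPT2 blocks over 4 stages: 3 blocks per stage, head only on stage 3.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
qa_outputs = nn.Linear(8, 2)
print(len(held_layers_for_stage(blocks, qa_outputs, stage=3, num_stages=4)))    # 4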
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
@@ -391,6 +416,8 @@ class GPT2PipelineForwards:
# Please refer to original code of transformers for more details.
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
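This logger is what the "transformers logger" bullet in the commit message refers to: transformers.utils.logging.get_logger returns a logger whose warning_once emits each distinct message only once per process, so the pipeline forward does not repeat the same warning for every micro-batch. A minimal standalone check (assuming a transformers version that provides warning_once):

from transformers.utils import logging

logger = logging.get_logger(__name__)
for _ in range(3):
    # Printed a single time, however often the forward pass reaches it.
    logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")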
@@ -416,7 +443,8 @@ class GPT2PipelineForwards:
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
else:
assert hidden_states is not None
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
@@ -478,21 +506,21 @@ class GPT2PipelineForwards:
# TODO: recording of past key/value tensors is left as () or None for now; this feature may be added in the future.
if past_key_values:
logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.')
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logging.warning('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logging.warning('use_cache=True is not supported for pipeline models at the moment.')
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
if self.gradient_checkpointing and self.training:
if use_cache:
logging.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
presents = () if use_cache else None
@@ -751,6 +779,94 @@ class GPT2PipelineForwards:
attentions=outputs.attentions,
)
@staticmethod
def gpt2_for_question_answering_forward(
self: 'GPT2ForQuestionAnswering',
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
# This function is adapted from transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward.
# Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import QuestionAnsweringModelOutput
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# On multi-GPU, gathering can add an extra dimension to the position tensors; squeeze it away
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# Sometimes the start/end positions fall outside the model inputs; such terms are ignored in the loss
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
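The loss block above mirrors the upstream QA head: qa_outputs produces two scores per token, which are split into start and end logits and each scored with cross entropy against the clamped gold positions. A self-contained sketch of just that computation, with illustrative shapes (not the commit's code):

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

batch, seq_len, hidden = 2, 16, 32
sequence_output = torch.randn(batch, seq_len, hidden)
qa_outputs = nn.Linear(hidden, 2)                       # one score for start, one for end

logits = qa_outputs(sequence_output)                    # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()    # (batch, seq_len)
end_logits = end_logits.squeeze(-1).contiguous()

start_positions = torch.tensor([0, 5])
end_positions = torch.tensor([3, 20])                   # 20 is deliberately out of range

# Out-of-range positions are clamped to seq_len and then ignored by the loss.
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss)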
@staticmethod
def gpt2_for_token_classification_forward(
self: 'GPT2ForTokenClassification',
@@ -852,6 +968,8 @@ class GPT2PipelineForwards:
# Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)
if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
@@ -892,7 +1010,7 @@ class GPT2PipelineForwards:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logging.warning(
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")

@@ -1 +1 @@
#from .torchrec import *
from .torchrec import *

@@ -29,6 +29,17 @@ def data_gen_for_lm():
return data
def data_gen_for_question_answering():
# question answering data gen
# `start_positions`/`end_positions` are token indices of the answer span, not token ids
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
return data
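Not part of the commit: a quick smoke test of what the new model-zoo entry exercises, building a small GPT2ForQuestionAnswering and feeding it inputs shaped like the output of data_gen_for_question_answering() (the tiny config values are arbitrary):

import torch
from transformers import GPT2Config, GPT2ForQuestionAnswering

config = GPT2Config(n_layer=2, n_head=4, n_embd=128, n_positions=64)
model = GPT2ForQuestionAnswering(config)

data = {
    'input_ids': torch.randint(0, config.vocab_size, (1, 10), dtype=torch.int64),
    'attention_mask': torch.ones(1, 10, dtype=torch.int64),
    'start_positions': torch.tensor([0], dtype=torch.int64),
    'end_positions': torch.tensor([1], dtype=torch.int64),
}

outputs = model(**data)
print(outputs.loss)                     # scalar QA loss
print(outputs.start_logits.shape)       # torch.Size([1, 10])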
def data_gen_for_token_classification():
# token classification data gen
# `labels` is the class index (0 or 1) for each token, not a token id
@@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads',
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,

@@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
