[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)

* change for transformers loggers

* add forward for GPT2ForQuestionAnswering

* fix assert

* fix torchrec test
Baizhou Zhang 2023-07-19 09:28:27 +08:00 committed by Hongxin Liu
parent d9be0472ef
commit 2a2eacfaf1
5 changed files with 147 additions and 11 deletions

View File

@@ -68,6 +68,8 @@ _POLICY_LIST = {
         PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
     "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
         PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
+    "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
+        PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
     "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
         PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
     "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":

View File

@@ -1,6 +1,4 @@
-import logging
 from functools import partial
-from types import MethodType
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
@@ -298,6 +296,33 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):

         return self.model

+
+# GPT2ForQuestionAnswering
+class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
+
+        module_policy = super().module_policy()
+        self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
+                                  new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
+                                  policy=module_policy)
+        return module_policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.qa_outputs)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        '''No shared_params in gpt2 for QA.'''
+        return []
+
+
 # GPT2ForTokenClassification
 class GPT2ForTokenClassificationPolicy(GPT2Policy):
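Conceptually, set_pipeline_forward swaps in the pipeline-aware forward so each stage only runs its held layers and hands `hidden_states` to the next stage. A rough sketch of what that binding amounts to, assuming GPT2PipelineForwards from this file is in scope (the replacement mechanism shown here is a simplification, not shardformer's internal code path):

# Simplified sketch of the effect of set_pipeline_forward for this policy.
from functools import partial

def sketch_bind_pipeline_forward(model, stage_manager, stage_index):
    # Bind the pipeline forward so non-final stages return {'hidden_states': ...}
    # and only the last stage computes span logits with model.qa_outputs.
    model.forward = partial(GPT2PipelineForwards.gpt2_for_question_answering_forward,
                            model,
                            stage_manager=stage_manager,
                            stage_index=stage_index)
    return model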
@@ -391,6 +416,8 @@ class GPT2PipelineForwards:
         # Please refer to original code of transformers for more details.

         from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+        from transformers.utils import logging
+        logger = logging.get_logger(__name__)

         # Preprocess passed in arguments
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -416,7 +443,8 @@ class GPT2PipelineForwards:
             if token_type_ids is not None:
                 token_type_ids = token_type_ids.view(-1, seq_length)
         else:
-            assert hidden_states is not None
+            if hidden_states is None:
+                raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
             input_shape = hidden_states.size()[:-1]
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
@@ -478,21 +506,21 @@ class GPT2PipelineForwards:
         # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if past_key_values:
-            logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
             past_key_values = None
         if output_attentions:
-            logging.warning('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
             output_attentions = False
         if output_hidden_states:
-            logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
         if use_cache:
-            logging.warning('use_cache=True is not supported for pipeline models at the moment.')
+            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
             use_cache = False

         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logging.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                 use_cache = False

         presents = () if use_cache else None
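The commit replaces stdlib logging calls with the transformers logger so the per-feature warnings are not repeated on every pipeline micro-batch. A small standalone example of that API (warning_once is available in recent transformers releases):

# transformers' logging utility: get_logger returns a logger with warning_once,
# which emits a given message only the first time it is seen.
from transformers.utils import logging

logger = logging.get_logger(__name__)

for step in range(3):
    # Printed once, even though the loop (like repeated pipeline forwards) hits it three times.
    logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")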
@@ -751,6 +779,94 @@ class GPT2PipelineForwards:
             attentions=outputs.attentions,
         )

+    @staticmethod
+    def gpt2_for_question_answering_forward(
+            self: 'GPT2ForQuestionAnswering',
+            input_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            start_positions: Optional[torch.LongTensor] = None,
+            end_positions: Optional[torch.LongTensor] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+
+        # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward.
+        # Please refer to original code of transformers for more details.
+        """
+        from transformers.modeling_outputs import QuestionAnsweringModelOutput
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1).to(start_logits.device)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1).to(end_logits.device)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
     @staticmethod
     def gpt2_for_token_classification_forward(
             self: 'GPT2ForTokenClassification',
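The span loss computed on the last stage above is the standard extractive-QA objective: cross-entropy over token positions for the start and end indices, averaged. A self-contained numeric sketch with dummy logits (not model output):

# Standalone illustration of the start/end span loss from the forward above.
# qa_outputs maps hidden states to 2 logits per token, hence the split on dim=-1.
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length = 1, 8
logits = torch.randn(batch_size, seq_length, 2)        # stand-in for self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()    # (batch_size, seq_length)
end_logits = end_logits.squeeze(-1).contiguous()

start_positions = torch.tensor([0])                      # same labels as data_gen_for_question_answering below
end_positions = torch.tensor([1])

ignored_index = start_logits.size(1)                     # out-of-range positions are ignored
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions.clamp(0, ignored_index)) +
              loss_fct(end_logits, end_positions.clamp(0, ignored_index))) / 2
print(total_loss)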
@@ -852,6 +968,8 @@ class GPT2PipelineForwards:
         # Please refer to original code of transformers for more details.
         """
         from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+        from transformers.utils import logging
+        logger = logging.get_logger(__name__)

         if input_ids is not None:
             batch_size, _ = input_ids.shape[:2]
@@ -892,7 +1010,7 @@ class GPT2PipelineForwards:
                 sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
             else:
                 sequence_lengths = -1
-                logging.warning(
+                logger.warning_once(
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`")

View File

@@ -1 +1 @@
-#from .torchrec import *
+from .torchrec import *

View File

@@ -29,6 +29,17 @@ def data_gen_for_lm():
     return data


+def data_gen_for_question_answering():
+    # question answering data gen: add `start_positions`/`end_positions` span labels
+    data = data_gen()
+    start_positions = torch.tensor([0], dtype=torch.int64)
+    data['start_positions'] = start_positions
+    end_positions = torch.tensor([1], dtype=torch.int64)
+    data['end_positions'] = end_positions
+    return data
+
+
 def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
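As a sanity check on the new generator, a hedged smoke-test sketch feeding QA labels straight into the HF model, outside the model zoo (the config values and token ids here are made up; the zoo's data_gen() supplies the real input_ids and attention_mask):

# Illustrative only: tiny config and made-up inputs, mirroring the span labels above.
import torch
import transformers

config = transformers.GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
model = transformers.GPT2ForQuestionAnswering(config)

data = {
    'input_ids': torch.tensor([[1, 2, 3, 4]], dtype=torch.int64),
    'attention_mask': torch.ones(1, 4, dtype=torch.int64),
    'start_positions': torch.tensor([0], dtype=torch.int64),
    'end_positions': torch.tensor([1], dtype=torch.int64),
}
outputs = model(**data)
print(outputs.loss, outputs.start_logits.shape)   # scalar loss plus (1, 4) span logits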
@@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_for_question_answering',
+                   model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_question_answering,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_token_classification',
                    model_fn=lambda: transformers.GPT2ForTokenClassification(config),
                    data_gen_fn=data_gen_for_token_classification,

View File

@@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
         inputs = data_gen_fn()
         inputs = {k: v.cuda() for k, v in inputs.items()}
         input_ids, _ = inputs['input_ids'], inputs['attention_mask']
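For reference, a sketch of how this loop could be reduced to a plain forward-pass smoke test over the GPT2 sub-registry, which now includes 'transformers_gpt_for_question_answering'. Everything beyond get_sub_registry() and data_gen_fn(), which are visible in the diff above, is an assumption about the test kit rather than its actual API:

# Sketch only: run each registered GPT2 variant once on its generated inputs.
def run_forward_smoke_test(model_zoo):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
        model = model_fn().cuda()
        inputs = {k: v.cuda() for k, v in data_gen_fn().items()}
        outputs = model(**inputs)
        print(name, type(outputs).__name__)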