mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)
* change for transformers loggers * add forward for GPT2ForQuestionAnswering * fix assert * fix torchrec testpull/4445/head
parent
d9be0472ef
commit
2a2eacfaf1
|
@ -68,6 +68,8 @@ _POLICY_LIST = {
|
|||
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
import logging
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@ -298,6 +296,33 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
return self.model
|
||||
|
||||
|
||||
# GPT2ForQuestionAnswering
|
||||
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
|
||||
|
||||
module_policy = super().module_policy()
|
||||
self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
|
||||
new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
|
||||
policy=module_policy)
|
||||
|
||||
return module_policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
'''No shared_params in gpt2 for QA.'''
|
||||
return []
|
||||
|
||||
|
||||
# GPT2ForTokenClassification
|
||||
class GPT2ForTokenClassificationPolicy(GPT2Policy):
|
||||
|
||||
|
@ -391,6 +416,8 @@ class GPT2PipelineForwards:
|
|||
# Please refer to original code of transformers for more details.
|
||||
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
@ -416,7 +443,8 @@ class GPT2PipelineForwards:
|
|||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||
else:
|
||||
assert hidden_states is not None
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape[0], input_shape[1]
|
||||
device = hidden_states.device
|
||||
|
@ -478,21 +506,21 @@ class GPT2PipelineForwards:
|
|||
|
||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if past_key_values:
|
||||
logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
past_key_values = None
|
||||
if output_attentions:
|
||||
logging.warning('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logging.warning('use_cache=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
use_cache = False
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logging.warning(
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
presents = () if use_cache else None
|
||||
|
@ -751,6 +779,94 @@ class GPT2PipelineForwards:
|
|||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gpt2_for_question_answering_forward(
|
||||
self: 'GPT2ForQuestionAnswering',
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
end_positions: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
"""
|
||||
from transformers.modeling_outputs import QuestionAnsweringModelOutput
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
return {'hidden_states': outputs['hidden_states']}
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
total_loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions = start_positions.clamp(0, ignored_index)
|
||||
end_positions = end_positions.clamp(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gpt2_for_token_classification_forward(
|
||||
self: 'GPT2ForTokenClassification',
|
||||
|
@ -852,6 +968,8 @@ class GPT2PipelineForwards:
|
|||
# Please refer to original code of transformers for more details.
|
||||
"""
|
||||
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
|
||||
from transformers.utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, _ = input_ids.shape[:2]
|
||||
|
@ -892,7 +1010,7 @@ class GPT2PipelineForwards:
|
|||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logging.warning(
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
#from .torchrec import *
|
||||
from .torchrec import *
|
||||
|
|
|
@ -29,6 +29,17 @@ def data_gen_for_lm():
|
|||
return data
|
||||
|
||||
|
||||
def data_gen_for_question_answering():
|
||||
# question answering data gen
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
data = data_gen()
|
||||
start_positions = torch.tensor([0], dtype=torch.int64)
|
||||
data['start_positions'] = start_positions
|
||||
end_positions = torch.tensor([1], dtype=torch.int64)
|
||||
data['end_positions'] = end_positions
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_token_classification():
|
||||
# token classification data gen
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
|
@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads',
|
|||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_question_answering',
|
||||
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_token_classification',
|
||||
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
|
|
|
@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
|||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
|
||||
|
|
Loading…
Reference in New Issue