mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)
* change for transformers loggers * add forward for GPT2ForQuestionAnswering * fix assert * fix torchrec testpull/4445/head
parent
d9be0472ef
commit
2a2eacfaf1
|
@ -68,6 +68,8 @@ _POLICY_LIST = {
|
||||||
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
|
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
|
||||||
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
|
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
|
||||||
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
|
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
|
||||||
|
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
|
||||||
|
PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
|
||||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
|
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
|
||||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
|
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
|
||||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
|
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
import logging
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from types import MethodType
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -298,6 +296,33 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
|
# GPT2ForQuestionAnswering
|
||||||
|
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
|
||||||
|
|
||||||
|
module_policy = super().module_policy()
|
||||||
|
self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
|
||||||
|
new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
|
||||||
|
policy=module_policy)
|
||||||
|
|
||||||
|
return module_policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[nn.Module]:
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
if self.pipeline_stage_manager.is_last_stage():
|
||||||
|
held_layers.append(self.model.qa_outputs)
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
'''No shared_params in gpt2 for QA.'''
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
# GPT2ForTokenClassification
|
# GPT2ForTokenClassification
|
||||||
class GPT2ForTokenClassificationPolicy(GPT2Policy):
|
class GPT2ForTokenClassificationPolicy(GPT2Policy):
|
||||||
|
|
||||||
|
@ -391,6 +416,8 @@ class GPT2PipelineForwards:
|
||||||
# Please refer to original code of transformers for more details.
|
# Please refer to original code of transformers for more details.
|
||||||
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.utils import logging
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# Preprocess passed in arguments
|
# Preprocess passed in arguments
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
@ -416,7 +443,8 @@ class GPT2PipelineForwards:
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids.view(-1, seq_length)
|
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||||
else:
|
else:
|
||||||
assert hidden_states is not None
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
||||||
input_shape = hidden_states.size()[:-1]
|
input_shape = hidden_states.size()[:-1]
|
||||||
batch_size, seq_length = input_shape[0], input_shape[1]
|
batch_size, seq_length = input_shape[0], input_shape[1]
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
@ -478,21 +506,21 @@ class GPT2PipelineForwards:
|
||||||
|
|
||||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
if past_key_values:
|
if past_key_values:
|
||||||
logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
logging.warning('output_attentions=True is not supported for pipeline models at the moment.')
|
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.')
|
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||||
output_hidden_states = False
|
output_hidden_states = False
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logging.warning('use_cache=True is not supported for pipeline models at the moment.')
|
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logging.warning(
|
logger.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||||
use_cache = False
|
use_cache = False
|
||||||
presents = () if use_cache else None
|
presents = () if use_cache else None
|
||||||
|
@ -751,6 +779,94 @@ class GPT2PipelineForwards:
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gpt2_for_question_answering_forward(
|
||||||
|
self: 'GPT2ForQuestionAnswering',
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
|
||||||
|
r"""
|
||||||
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||||
|
are not taken into account for computing the loss.
|
||||||
|
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||||
|
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||||
|
are not taken into account for computing the loss.
|
||||||
|
|
||||||
|
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward.
|
||||||
|
# Please refer to original code of transformers for more details.
|
||||||
|
"""
|
||||||
|
from transformers.modeling_outputs import QuestionAnsweringModelOutput
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index)
|
||||||
|
|
||||||
|
# If not at the last stage, return hidden_states as in GPT2Model
|
||||||
|
if not stage_manager.is_last_stage():
|
||||||
|
return {'hidden_states': outputs['hidden_states']}
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
|
logits = self.qa_outputs(sequence_output)
|
||||||
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
|
start_logits = start_logits.squeeze(-1).contiguous()
|
||||||
|
end_logits = end_logits.squeeze(-1).contiguous()
|
||||||
|
|
||||||
|
total_loss = None
|
||||||
|
if start_positions is not None and end_positions is not None:
|
||||||
|
# If we are on multi-GPU, split add a dimension
|
||||||
|
if len(start_positions.size()) > 1:
|
||||||
|
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
||||||
|
if len(end_positions.size()) > 1:
|
||||||
|
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
||||||
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
|
ignored_index = start_logits.size(1)
|
||||||
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
|
total_loss = (start_loss + end_loss) / 2
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (start_logits, end_logits) + outputs[2:]
|
||||||
|
return ((total_loss,) + output) if total_loss is not None else output
|
||||||
|
|
||||||
|
return QuestionAnsweringModelOutput(
|
||||||
|
loss=total_loss,
|
||||||
|
start_logits=start_logits,
|
||||||
|
end_logits=end_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gpt2_for_token_classification_forward(
|
def gpt2_for_token_classification_forward(
|
||||||
self: 'GPT2ForTokenClassification',
|
self: 'GPT2ForTokenClassification',
|
||||||
|
@ -852,6 +968,8 @@ class GPT2PipelineForwards:
|
||||||
# Please refer to original code of transformers for more details.
|
# Please refer to original code of transformers for more details.
|
||||||
"""
|
"""
|
||||||
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
|
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
|
||||||
|
from transformers.utils import logging
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
batch_size, _ = input_ids.shape[:2]
|
batch_size, _ = input_ids.shape[:2]
|
||||||
|
@ -892,7 +1010,7 @@ class GPT2PipelineForwards:
|
||||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logging.warning(
|
logger.warning_once(
|
||||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
#from .torchrec import *
|
from .torchrec import *
|
||||||
|
|
|
@ -29,6 +29,17 @@ def data_gen_for_lm():
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def data_gen_for_question_answering():
|
||||||
|
# question answering data gen
|
||||||
|
# `labels` is the type not the token id for token classification, 0 or 1
|
||||||
|
data = data_gen()
|
||||||
|
start_positions = torch.tensor([0], dtype=torch.int64)
|
||||||
|
data['start_positions'] = start_positions
|
||||||
|
end_positions = torch.tensor([1], dtype=torch.int64)
|
||||||
|
data['end_positions'] = end_positions
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def data_gen_for_token_classification():
|
def data_gen_for_token_classification():
|
||||||
# token classification data gen
|
# token classification data gen
|
||||||
# `labels` is the type not the token id for token classification, 0 or 1
|
# `labels` is the type not the token id for token classification, 0 or 1
|
||||||
|
@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads',
|
||||||
output_transform_fn=output_transform_fn,
|
output_transform_fn=output_transform_fn,
|
||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
model_attribute=ModelAttribute(has_control_flow=True))
|
model_attribute=ModelAttribute(has_control_flow=True))
|
||||||
|
model_zoo.register(name='transformers_gpt_for_question_answering',
|
||||||
|
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
|
||||||
|
data_gen_fn=data_gen_for_question_answering,
|
||||||
|
output_transform_fn=output_transform_fn,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
model_attribute=ModelAttribute(has_control_flow=True))
|
||||||
model_zoo.register(name='transformers_gpt_for_token_classification',
|
model_zoo.register(name='transformers_gpt_for_token_classification',
|
||||||
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
|
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
|
||||||
data_gen_fn=data_gen_for_token_classification,
|
data_gen_fn=data_gen_for_token_classification,
|
||||||
|
|
|
@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||||
|
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||||
|
|
||||||
inputs = data_gen_fn()
|
inputs = data_gen_fn()
|
||||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
|
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
|
||||||
|
|
Loading…
Reference in New Issue