From 2a2eacfaf17b17e5bcb4cd334303a1137ebdfb84 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 19 Jul 2023 09:28:27 +0800 Subject: [PATCH] [pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245) * change for transformers loggers * add forward for GPT2ForQuestionAnswering * fix assert * fix torchrec test --- .../shardformer/policies/auto_policy.py | 2 + colossalai/shardformer/policies/gpt2.py | 136 ++++++++++++++++-- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/kit/model_zoo/transformers/gpt.py | 17 +++ .../test_model/test_shard_gpt2_pipeline.py | 1 - 5 files changed, 147 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index ccdb33b2e..b31f1b35f 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -68,6 +68,8 @@ _POLICY_LIST = { PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": + PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5d6f47636..05178895d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,6 +1,4 @@ -import logging from functools import partial -from types import MethodType from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -298,6 +296,33 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): return self.model +# GPT2ForQuestionAnswering +class GPT2ForQuestionAnsweringPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering + + module_policy = super().module_policy() + self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy) + + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''No shared_params in gpt2 for QA.''' + return [] + + # GPT2ForTokenClassification class GPT2ForTokenClassificationPolicy(GPT2Policy): @@ -391,6 +416,8 @@ class GPT2PipelineForwards: # Please refer to original code of transformers for more details. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + from transformers.utils import logging + logger = logging.get_logger(__name__) # Preprocess passed in arguments output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -416,7 +443,8 @@ class GPT2PipelineForwards: if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_length) else: - assert hidden_states is not None + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device @@ -478,21 +506,21 @@ class GPT2PipelineForwards: # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None if output_attentions: - logging.warning('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False if output_hidden_states: - logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False if use_cache: - logging.warning('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') use_cache = False if self.gradient_checkpointing and self.training: if use_cache: - logging.warning( + logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False presents = () if use_cache else None @@ -751,6 +779,94 @@ class GPT2PipelineForwards: attentions=outputs.attentions, ) + @staticmethod + def gpt2_for_question_answering_forward( + self: 'GPT2ForQuestionAnswering', + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. + """ + from transformers.modeling_outputs import QuestionAnsweringModelOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + @staticmethod def gpt2_for_token_classification_forward( self: 'GPT2ForTokenClassification', @@ -852,6 +968,8 @@ class GPT2PipelineForwards: # Please refer to original code of transformers for more details. """ from transformers.modeling_outputs import SequenceClassifierOutputWithPast + from transformers.utils import logging + logger = logging.get_logger(__name__) if input_ids is not None: batch_size, _ = input_ids.shape[:2] @@ -892,7 +1010,7 @@ class GPT2PipelineForwards: sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 - logging.warning( + logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`") diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449..43952e699 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index f9a0888ff..0fbcaa1e2 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -29,6 +29,17 @@ def data_gen_for_lm(): return data +def data_gen_for_question_answering(): + # question answering data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data['start_positions'] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data['end_positions'] = end_positions + return data + + def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 @@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads', output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_for_question_answering', + model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_token_classification', model_fn=lambda: transformers.GPT2ForTokenClassification(config), data_gen_fn=data_gen_for_token_classification, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index dd439a394..005e3d6f8 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} input_ids, _ = inputs['input_ids'], inputs['attention_mask']