[pipeline] fix return_dict/fix pure_pipeline_test (#4331)

pull/4445/head
Baizhou Zhang 2023-07-27 14:53:20 +08:00 committed by Hongxin Liu
parent 411cf1d2db
commit da3cef27ad
5 changed files with 29 additions and 53 deletions

@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Optional, Tuple

 import torch
@@ -277,9 +278,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -387,9 +385,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -478,9 +473,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -579,16 +571,15 @@ class BertPipelineForwards:
                 FutureWarning,
             )
             labels = kwargs.pop("next_sentence_label")
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if output_attentions:
             logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
             output_attentions = False
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(self.bert,
                                                           input_ids,
@@ -661,10 +652,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(self.bert,
                                                           input_ids,
@@ -753,10 +740,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -832,10 +815,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # in our pipeline design,input ids are copied for every stage and shouldn't be none
         # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
@@ -928,10 +907,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
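
The deletions above are the whole of the BERT-side return_dict fix: every task forward carried a three-line override that forced `return_dict = False` (with a warning), and in several forwards the normalization line that followed could never undo it, since False is not None, so a caller's explicit `return_dict=True` was silently discarded. A minimal sketch of the old versus new flag handling, using a stand-in config class rather than the real transformers/ColossalAI objects:

    # Stand-ins to illustrate control flow only; not ColossalAI code.
    class _Config:
        use_return_dict = True

    config = _Config()

    def old_flag_handling(return_dict):
        # Old: an explicit request for ModelOutput objects was overridden...
        if return_dict:
            return_dict = False
        # ...and the normalization that followed kept the forced value,
        # because False is not None.
        return return_dict if return_dict is not None else config.use_return_dict

    def new_flag_handling(return_dict):
        # New: the flag flows through unchanged (normalized at most once),
        # so the caller's choice survives down to bert_model_forward.
        return return_dict if return_dict is not None else config.use_return_dict

    assert old_flag_handling(True) is False    # the bug: request silently dropped
    assert new_flag_handling(True) is True     # request honored
    assert new_flag_handling(None) is True     # None still falls back to the default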

@@ -313,9 +313,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
                                                                         input_ids,
@@ -411,9 +408,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -537,9 +531,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -626,9 +617,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
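
The Bloom forwards receive the identical treatment. One point worth noting, as a hedged reading of the pipeline design rather than something this diff states: `return_dict` only shapes the output of the last pipeline stage, because intermediate stages always pass a plain dict of activations to the next stage. A toy illustration, assuming `stage_manager.is_last_stage()` behaves as in colossalai.pipeline.stage_manager:

    def stage_output(hidden_states, stage_manager, build_model_output):
        # Intermediate stages ship raw activations onward regardless of the flag;
        # only the last stage builds the user-facing tuple or ModelOutput.
        if not stage_manager.is_last_stage():
            return {'hidden_states': hidden_states}
        return build_model_output(hidden_states)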

@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
         # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
         # Please refer to original code of transformers for more details.

+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         logger = logging.get_logger(__name__)

         # Preprocess passed in arguments
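
For GPT-2 the fix is an addition rather than a deletion: `gpt2_model_forward` apparently lacked the resolution step entirely, so the new line applies the standard transformers idiom once, at the top of the shared model forward, and every task head that routes through it inherits consistent behavior. Spelled out with a stand-in config (the same idiom as in the BERT sketch above):

    class _Cfg:
        use_return_dict = True   # stand-in for GPT2Config.use_return_dict

    def resolve_return_dict(return_dict, config=_Cfg()):
        # None means "use the config default"; an explicit bool wins.
        return return_dict if return_dict is not None else config.use_return_dict

    assert resolve_return_dict(None) is True
    assert resolve_return_dict(False) is False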

@@ -8,6 +8,18 @@ import torch
 import torch.nn as nn
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+    OPTForCausalLM,
+    OPTForQuestionAnswering,
+    OPTForSequenceClassification,
+    OPTModel,
+)

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
@@ -317,7 +329,7 @@ class OPTPipelineForwards:
     @staticmethod
     def opt_model_forward(
-        self: 'OPTModel',
+        self: OPTModel,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'BaseModelOutputWithPast']:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         '''
         This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
         '''
@@ -506,7 +518,7 @@ class OPTPipelineForwards:
     @staticmethod
     def opt_for_causal_lm_forward(
-        self: 'OPTForCausalLM',
+        self: OPTForCausalLM,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'CausalLMOutputWithPast']:
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -646,7 +658,7 @@ class OPTPipelineForwards:
     @staticmethod
     def opt_for_sequence_classification_forward(
-        self: 'OPTForSequenceClassification',
+        self: OPTForSequenceClassification,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -746,7 +758,7 @@ class OPTPipelineForwards:
     @staticmethod
     def opt_for_question_answering_forward(
-        self: 'OPTForQuestionAnswering',
+        self: OPTForQuestionAnswering,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
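
The OPT changes swap quoted forward references ('OPTModel', 'BaseModelOutputWithPast', ...) for real imports, so the annotations now name objects that exist at runtime, and a typo fails fast at import time instead of lingering as a dead string. For contrast only (this commit deliberately imports for real), the common alternative that keeps heavyweight imports out of the runtime path is typing.TYPE_CHECKING:

    # Contrast sketch, not what this commit does: the import is visible to type
    # checkers, while at runtime the annotation stays an unevaluated string.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from transformers.models.opt.modeling_opt import OPTModel

    def opt_model_forward(self: 'OPTModel', input_ids=None):
        ...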

@@ -1,6 +1,5 @@
 import copy
 import random
-from contextlib import nullcontext
 from typing import Any, Callable, Iterator, List, Optional, Tuple

 import numpy as np
@@ -100,8 +99,8 @@ class data_loader():
         return torch.ones((4, 128), dtype=torch.int).cuda() * 10

-def loss(x, y):
-    return (x[0].float().mean() - y[0].float().mean())
+def loss(y, x):
+    return (y[0].float().mean() - x[0].float().mean())

 @parameterize('enable_fused_normalization', [False])
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     batch = next(data_iter)
     with torch.no_grad():
         y = model_copy(batch)
-    org_loss = loss(batch, y)
+    org_loss = loss(y, batch)

     optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
     schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
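
The test fix corrects argument order rather than arithmetic: the pipeline schedule invokes the criterion with the model output first, so the hand-computed reference `org_loss` must use the same (output, input) order, or the reference and pipeline losses differ in sign and the comparison fails. That the schedule calls criterion(output, batch) is inferred from this diff, not stated in it; a self-contained check of the sign flip, with a fabricated output tensor:

    import torch

    def loss(y, x):
        # y: model outputs, x: input batch -- the order the schedule uses.
        return y[0].float().mean() - x[0].float().mean()

    batch = torch.ones((4, 128), dtype=torch.int) * 10   # shape data_loader yields
    fake_output = (torch.full((4, 128), 12.0),)          # hypothetical model output

    # Swapping the arguments flips the sign of the loss.
    assert loss(fake_output, (batch,)) == -loss((batch,), fake_output)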