diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index df64c93cf..1b3c14d9d 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -277,9 +278,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -387,9 +385,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -478,9 +473,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -579,16 +571,15 @@ class BertPipelineForwards:
                 FutureWarning,
             )
             labels = kwargs.pop("next_sentence_label")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         if output_attentions:
             logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
             output_attentions = False
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = BertPipelineForwards.bert_model_forward(self.bert,
                                                           input_ids,
@@ -661,10 +652,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = BertPipelineForwards.bert_model_forward(self.bert,
                                                           input_ids,
@@ -753,10 +740,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -832,10 +815,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # in our pipeline design,input ids are copied for every stage and shouldn't be none
         # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
@@ -928,10 +907,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index fd200665d..76948fc70 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -313,9 +313,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
                                                                         input_ids,
@@ -411,9 +408,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -537,9 +531,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -626,9 +617,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
 
         outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 5519d0b30..dc5a81dc9 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
         # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
         # Please refer to original code of transformers for more details.
 
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         logger = logging.get_logger(__name__)
 
         # Preprocess passed in arguments
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 244a0a54e..6fc3a2d31 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -8,6 +8,18 @@ import torch
 import torch.nn as nn
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+    OPTForCausalLM,
+    OPTForQuestionAnswering,
+    OPTForSequenceClassification,
+    OPTModel,
+)
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
@@ -317,7 +329,7 @@ class OPTPipelineForwards:
 
     @staticmethod
     def opt_model_forward(
-        self: 'OPTModel',
+        self: OPTModel,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'BaseModelOutputWithPast']:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         '''
         This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
         '''
@@ -506,7 +518,7 @@ class OPTPipelineForwards:
 
     @staticmethod
     def opt_for_causal_lm_forward(
-        self: 'OPTForCausalLM',
+        self: OPTForCausalLM,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'CausalLMOutputWithPast']:
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -646,7 +658,7 @@ class OPTPipelineForwards:
 
     @staticmethod
     def opt_for_sequence_classification_forward(
-        self: 'OPTForSequenceClassification',
+        self: OPTForSequenceClassification,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -746,7 +758,7 @@ class OPTPipelineForwards:
 
     @staticmethod
    def opt_for_question_answering_forward(
-        self: 'OPTForQuestionAnswering',
+        self: OPTForQuestionAnswering,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py
index c3cd05095..576e6473b 100644
--- a/tests/test_shardformer/test_model/test_pure_pipeline.py
+++ b/tests/test_shardformer/test_model/test_pure_pipeline.py
@@ -1,6 +1,5 @@
 import copy
 import random
-from contextlib import nullcontext
 from typing import Any, Callable, Iterator, List, Optional, Tuple
 
 import numpy as np
@@ -100,8 +99,8 @@ class data_loader():
         return torch.ones((4, 128), dtype=torch.int).cuda() * 10
 
 
-def loss(x, y):
-    return (x[0].float().mean() - y[0].float().mean())
+def loss(y, x):
+    return (y[0].float().mean() - x[0].float().mean())
 
 
 @parameterize('enable_fused_normalization', [False])
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     batch = next(data_iter)
     with torch.no_grad():
         y = model_copy(batch)
-    org_loss = loss(batch, y)
+    org_loss = loss(y, batch)
     optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
     schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
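
The hunks above repeat one idiom: instead of warning and forcing `return_dict = False`, each pipeline forward now resolves `return_dict` from the model config. A minimal sketch of that pattern follows; the function name `example_pipeline_forward` is a hypothetical placeholder for illustration, not part of the ColossalAI API, and it assumes a HuggingFace-style model whose config exposes `use_return_dict`.

    # Minimal sketch (assumed names), mirroring the return_dict handling in the diff above.
    def example_pipeline_forward(self, input_ids, output_attentions=None, output_hidden_states=None, return_dict=None):
        # Fall back to the config default when the caller leaves return_dict unset.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Attention maps and per-layer hidden states remain unsupported across pipeline stages,
        # so they are still disabled rather than propagated.
        if output_attentions:
            output_attentions = False
        if output_hidden_states:
            output_hidden_states = False
        ...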