mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] fix return_dict/fix pure_pipeline_test (#4331)
parent 411cf1d2db
commit da3cef27ad
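In brief, as read from the hunks below: the Bert and Bloom pipeline forwards stop warning about and force-disabling `return_dict`, resolving it from `config.use_return_dict` instead; GPT2 gains the same resolution line; the OPT modeling code imports the transformers output and model classes at module level so its quoted type annotations become direct class references; and the pure-pipeline test drops a `from contextlib import nullcontext` import and swaps the criterion's argument order to take the model output first.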
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Optional, Tuple

 import torch
@@ -277,9 +278,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -387,9 +385,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -478,9 +473,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -579,16 +571,15 @@ class BertPipelineForwards:
                 FutureWarning,
             )
             labels = kwargs.pop("next_sentence_label")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         if output_attentions:
             logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
             output_attentions = False
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(self.bert,
                                                           input_ids,
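Taken together, the Bert hunks converge on one preamble shape: `return_dict` is resolved once from the model config and passed through, instead of being warned about and forced to `False` on every stage. A minimal sketch of the resulting pattern, assuming the HF-style forward signature visible in the context lines (the surrounding function body is elided, not a verbatim excerpt):

    # sketch of the post-change preamble
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if output_attentions:
        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
        output_attentions = False
    if output_hidden_states:
        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
        output_hidden_states = False
    # no more "if return_dict: ... return_dict = False" block; the resolved
    # value flows down to bert_model_forward unchanged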
@@ -661,10 +652,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(self.bert,
                                                           input_ids,
@@ -753,10 +740,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -832,10 +815,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # in our pipeline design,input ids are copied for every stage and shouldn't be none
         # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
@@ -928,10 +907,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
@@ -313,9 +313,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
                                                                         input_ids,
@@ -411,9 +408,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -537,9 +531,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -626,9 +617,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
         # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
         # Please refer to original code of transformers for more details.

+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         logger = logging.get_logger(__name__)

         # Preprocess passed in arguments
@@ -8,6 +8,18 @@ import torch
 import torch.nn as nn
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+    OPTForCausalLM,
+    OPTForQuestionAnswering,
+    OPTForSequenceClassification,
+    OPTModel,
+)

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
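The new module-level imports are what allow the signature changes in the hunks that follow: the OPT forwards previously used quoted annotations such as 'OPTModel' as forward references, and they can now reference the classes directly. A toy, self-contained illustration of the difference, using a hypothetical Widget class rather than anything from the repo:

    from typing import Optional

    class Widget:
        pass

    # quoted annotation: a forward reference, resolved lazily, so it works
    # even if Widget is not in scope when the def statement executes
    def process_quoted(w: 'Widget') -> Optional['Widget']:
        return w

    # direct annotation: Widget must already be imported or defined here
    def process_direct(w: Widget) -> Optional[Widget]:
        return w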
@@ -317,7 +329,7 @@ class OPTPipelineForwards:

     @staticmethod
     def opt_model_forward(
-        self: 'OPTModel',
+        self: OPTModel,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'BaseModelOutputWithPast']:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         '''
         This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
         '''
@@ -506,7 +518,7 @@ class OPTPipelineForwards:

     @staticmethod
     def opt_for_causal_lm_forward(
-        self: 'OPTForCausalLM',
+        self: OPTForCausalLM,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'CausalLMOutputWithPast']:
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -646,7 +658,7 @@ class OPTPipelineForwards:

     @staticmethod
     def opt_for_sequence_classification_forward(
-        self: 'OPTForSequenceClassification',
+        self: OPTForSequenceClassification,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -746,7 +758,7 @@ class OPTPipelineForwards:

     @staticmethod
     def opt_for_question_answering_forward(
-        self: 'OPTForQuestionAnswering',
+        self: OPTForQuestionAnswering,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-    ) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1,6 +1,5 @@
 import copy
 import random
-from contextlib import nullcontext
 from typing import Any, Callable, Iterator, List, Optional, Tuple

 import numpy as np
@@ -100,8 +99,8 @@ class data_loader():
         return torch.ones((4, 128), dtype=torch.int).cuda() * 10


-def loss(x, y):
-    return (x[0].float().mean() - y[0].float().mean())
+def loss(y, x):
+    return (y[0].float().mean() - x[0].float().mean())


 @parameterize('enable_fused_normalization', [False])
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     batch = next(data_iter)
     with torch.no_grad():
         y = model_copy(batch)
-    org_loss = loss(batch, y)
+    org_loss = loss(y, batch)
     optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
     schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
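The two test hunks are one fix: as the swapped call site suggests, the pipeline schedule hands the criterion the model output as its first argument, so `loss` is redefined as `loss(y, x)` and the eager reference run calls `loss(y, batch)` to match. A self-contained sketch of the convention (CPU tensors with illustrative shapes, standing in for the test's CUDA batch and model outputs):

    import torch

    def loss(y, x):
        # y: model output (first positional argument), x: the micro-batch
        return (y[0].float().mean() - x[0].float().mean())

    batch = torch.ones((4, 128), dtype=torch.int) * 10   # stand-in for data_loader()
    y = (torch.randn(4, 128),)                           # stand-in for model_copy(batch)
    org_loss = loss(y, batch)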