[pipeline] add pipeline forward for variants of gpt2 (#4238)

* add forward for GPTLMHeadModel

* add test for gpt_lm

* arranging get_held_layers method

* arrange forward replacement

* add forward for GPT2ForTokenClassification

* add forward for GPT2ForSequenceClassification

* fix test_shard_gpt2.py

* add GPT2DoubleHeadsmodel & fix bugs

* add id checking in get_shared_params
pull/4445/head
Baizhou Zhang 2023-07-17 15:21:51 +08:00 committed by Hongxin Liu
parent 7e4de520e1
commit a14d352088
2 changed files with 529 additions and 66 deletions

View File

@ -1,11 +1,11 @@
import logging
from functools import partial
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -48,6 +48,10 @@ class GPT2Policy(Policy):
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@ -120,6 +124,45 @@ class GPT2Policy(Policy):
def postprocess(self):
return self.model
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'GPT2Model':
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.wpe)
held_layers.append(module.drop)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'GPT2Model':
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
@ -131,40 +174,16 @@ class GPT2ModelPolicy(GPT2Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(GPT2PipelineForwards.gpt2_model_forward,
stage_manager=stage_manager,
stage_index=stage_index)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=GPT2Model)
self.set_pipeline_forward(model_cls=GPT2Model,
new_forward=GPT2PipelineForwards.gpt2_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
module = self.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.wpe)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# TODO: check whether there is shared param in gpt2model
"""No shared params in gpt2 model."""
"""No shared params in GPT2Model."""
return []
@ -188,10 +207,31 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
])
}
module_policy.update(addon_module)
self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
else:
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
if self.shard_config.enable_tensor_parallelism \
and self.pipeline_stage_manager is None:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@ -199,7 +239,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
return self.model
# GPT22DoubleHeadsModel
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
@ -219,10 +259,38 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
])
}
module_policy.update(addon_module)
self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
multiple_choice_head = self.model.multiple_choice_head
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
held_layers.append(multiple_choice_head.activation)
held_layers.append(multiple_choice_head.first_dropout)
held_layers.append(multiple_choice_head.last_dropout)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
else:
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
if self.shard_config.enable_tensor_parallelism \
and self.pipeline_stage_manager is None:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@ -236,6 +304,36 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2ForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
])
}
module_policy.update(addon_module)
self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2ForTokenClassification."""
return []
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
@ -243,6 +341,25 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
module_policy = super().module_policy()
self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2ForTokenClassification."""
return []
class GPT2PipelineForwards:
'''
@ -299,8 +416,7 @@ class GPT2PipelineForwards:
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
else:
if hidden_states is None:
raise ValueError("hidden_states shouln't be None for stages other than the first stage.")
assert hidden_states is not None
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
@ -462,3 +578,356 @@ class GPT2PipelineForwards:
else:
# always return dict for intermediate stage
return {'hidden_states': hidden_states}
@staticmethod
def gpt2_lmhead_model_forward(
self: 'GPT2LMHeadModel',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@staticmethod
def gpt2_double_heads_model_forward(
self: 'GPT2DoubleHeadsModel',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mc_token_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
mc_labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1]`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
`-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward.
Please refer to original code of transformers for more details.
```"""
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
mc_loss = None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
lm_loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits, mc_logits) + outputs[1:]
if mc_loss is not None:
output = (mc_loss,) + output
return ((lm_loss,) + output) if lm_loss is not None else output
return GPT2DoubleHeadsModelOutput(
loss=lm_loss,
mc_loss=mc_loss,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@staticmethod
def gpt2_for_token_classification_forward(
self: 'GPT2ForTokenClassification',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward.
# Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import TokenClassifierOutput
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@staticmethod
def gpt2_for_sequence_classification_forward(
self: 'GPT2ForSequenceClassification',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
# Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
else:
batch_size, _ = hidden_states.shape[:2]
assert (self.config.pad_token_id is not None
or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logging.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -5,15 +5,9 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
pass
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_gpt2
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name != "transformers_gpt":
continue
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
batch_size, seq_len = input_ids.shape
hidden_size = 768
hidden_state_shape = (batch_size, seq_len, hidden_size)
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
org_model.train()
org_output = org_model(**inputs)
hidden_state_shape = org_output['last_hidden_state'].shape
if stage_manager.is_first_stage():
output = sharded_model(**inputs)
assert output['hidden_states'].shape == hidden_state_shape
else:
attention_mask = inputs['attention_mask']
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask)
if stage_manager.is_last_stage():
assert output['last_hidden_state'].shape == hidden_state_shape
else:
assert output['hidden_states'].shape == hidden_state_shape
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name != 'transformers_gpt':
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()