[pipeline] add pipeline forward for variants of gpt2 (#4238)

* add forward for GPTLMHeadModel

* add test for gpt_lm

* arranging get_held_layers method

* arrange forward replacement

* add forward for GPT2ForTokenClassification

* add forward for GPT2ForSequenceClassification

* fix test_shard_gpt2.py

* add GPT2DoubleHeadsModel & fix bugs

* add id checking in get_shared_params
Baizhou Zhang 2023-07-17 15:21:51 +08:00 committed by Hongxin Liu
parent 7e4de520e1
commit a14d352088
2 changed files with 529 additions and 66 deletions


@@ -1,11 +1,11 @@
import logging
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -48,6 +48,10 @@ class GPT2Policy(Policy):
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -120,6 +124,45 @@ class GPT2Policy(Policy):
def postprocess(self):
return self.model
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == 'GPT2Model':
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.wpe)
held_layers.append(module.drop)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'GPT2Model':
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
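
# Illustrative sketch of the layer split that get_held_layers/set_pipeline_forward rely on.
# This is a plain even-split version for intuition only; the actual Policy.distribute_layers /
# Policy.get_stage_index in ColossalAI may balance layers differently.
def distribute_layers_evenly(num_layers: int, num_stages: int) -> list:
    # each stage gets num_layers // num_stages blocks; earlier stages absorb the remainder
    quotient, remainder = divmod(num_layers, num_stages)
    return [quotient + (1 if stage < remainder else 0) for stage in range(num_stages)]

def stage_span(layers_per_stage: list, stage: int) -> tuple:
    # (start_idx, end_idx) used to slice module.h[start_idx:end_idx] for this stage
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# e.g. 12 GPT-2 blocks over 4 stages -> [3, 3, 3, 3]; stage 2 would hold h[6:9]
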
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
@@ -131,40 +174,16 @@ class GPT2ModelPolicy(GPT2Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
policy = super().module_policy()
self.set_pipeline_forward(model_cls=GPT2Model,
new_forward=GPT2PipelineForwards.gpt2_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2Model."""
return []
@@ -188,10 +207,31 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
])
}
module_policy.update(addon_module)
self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
else:
return []
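
# Quick sanity sketch of the id() check used above: with HF-style weight tying, wte.weight and
# lm_head.weight are literally the same tensor object, so comparing ids detects the tie.
# Dimensions below are just the GPT-2 defaults and only serve as an example.
import torch
wte = torch.nn.Embedding(50257, 768)
lm_head = torch.nn.Linear(768, 50257, bias=False)
lm_head.weight = wte.weight                      # tie the weights, as GPT2LMHeadModel does
assert id(wte.weight) == id(lm_head.weight)      # same object -> shared param across first/last stage
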
def postprocess(self):
if self.shard_config.enable_tensor_parallelism \
and self.pipeline_stage_manager is None:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@@ -199,7 +239,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
return self.model
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
@@ -219,10 +259,38 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
])
}
module_policy.update(addon_module)
self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
multiple_choice_head = self.model.multiple_choice_head
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
held_layers.append(multiple_choice_head.activation)
held_layers.append(multiple_choice_head.first_dropout)
held_layers.append(multiple_choice_head.last_dropout)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
else:
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism \
and self.pipeline_stage_manager is None:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
@@ -236,6 +304,36 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2ForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
])
}
module_policy.update(addon_module)
self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2ForTokenClassification."""
return []
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
@@ -243,6 +341,25 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
module_policy = super().module_policy()
self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2ForTokenClassification."""
return []
class GPT2PipelineForwards:
'''
@@ -299,8 +416,7 @@ class GPT2PipelineForwards:
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
else:
assert hidden_states is not None
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
@@ -462,3 +578,356 @@ class GPT2PipelineForwards:
else:
# always return dict for intermediate stage
return {'hidden_states': hidden_states}
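
# Toy sketch of the hand-off convention used by these pipeline forwards: every stage except the
# last returns {'hidden_states': ...}, and the next stage consumes that tensor through its
# hidden_states= argument instead of input_ids. Shapes and values below are dummies for illustration.
import torch

def toy_stage(hidden_states: torch.Tensor, is_last_stage: bool) -> dict:
    hidden_states = hidden_states + 1.0                # stand-in for a slice of GPT2Blocks
    if is_last_stage:
        return {'last_hidden_state': hidden_states}    # final stage: full model output
    return {'hidden_states': hidden_states}            # intermediate stage: pass activations on

out = {'hidden_states': torch.zeros(2, 8, 768)}        # (batch, seq_len, hidden_size)
for stage in range(4):
    out = toy_stage(out['hidden_states'], is_last_stage=(stage == 3))
# out now holds the final stage's output
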
@staticmethod
def gpt2_lmhead_model_forward(
self: 'GPT2LMHeadModel',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
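
# Worked example of the label shifting above: logits at position t predict the token at t+1, so
# the last logit and the first label are dropped before CrossEntropyLoss. Vocab size and shapes
# here are arbitrary illustration values.
import torch
vocab = 50257
logits = torch.randn(1, 5, vocab)                 # (batch, seq_len, vocab)
labels = torch.randint(0, vocab, (1, 5))
shift_logits = logits[..., :-1, :].contiguous()   # predictions for positions 0..3
shift_labels = labels[..., 1:].contiguous()       # targets are tokens 1..4
loss = torch.nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
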
@staticmethod
def gpt2_double_heads_model_forward(
self: 'GPT2DoubleHeadsModel',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mc_token_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
mc_labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1]`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
`-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward.
Please refer to original code of transformers for more details.
```"""
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
mc_loss = None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
lm_loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits, mc_logits) + outputs[1:]
if mc_loss is not None:
output = (mc_loss,) + output
return ((lm_loss,) + output) if lm_loss is not None else output
return GPT2DoubleHeadsModelOutput(
loss=lm_loss,
mc_loss=mc_loss,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@staticmethod
def gpt2_for_token_classification_forward(
self: 'GPT2ForTokenClassification',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward.
# Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import TokenClassifierOutput
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
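
# Small sketch of the token-classification loss above: logits are flattened to
# (batch * seq_len, num_labels) and labels to (batch * seq_len,) before CrossEntropyLoss.
# The num_labels and shape values here are arbitrary.
import torch
num_labels = 9
logits = torch.randn(2, 4, num_labels)            # (batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 4))
labels[0, -1] = -100                              # -100 positions are ignored by default
loss = torch.nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
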
@staticmethod
def gpt2_for_sequence_classification_forward(
self: 'GPT2ForSequenceClassification',
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
# Please refer to original code of transformers for more details.
"""
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
else:
batch_size, _ = hidden_states.shape[:2]
assert (self.config.pad_token_id is not None
or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}
hidden_states = outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logging.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
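
# Sketch of the last-token pooling above: when pad_token_id is set, sequence_lengths picks the
# index of the final non-padding token per row, and pooled_logits gathers the logits there.
# The pad id, shapes, and label count are illustration values only.
import torch
pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])
logits = torch.randn(2, 5, 3)                                      # (batch, seq_len, num_labels)
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 1])
pooled_logits = logits[torch.arange(2), sequence_lengths]          # shape (2, num_labels)
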


@@ -5,15 +5,9 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
pass

@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_gpt2
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
batch_size, seq_len = input_ids.shape
hidden_size = 768
hidden_state_shape = (batch_size, seq_len, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name != 'transformers_gpt':
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()