mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] add pipeline forward for variants of gpt2 (#4238)
* add forward for GPTLMHeadModel
* add test for gpt_lm
* arranging get_held_layers method
* arrange forward replacement
* add forward for GPT2ForTokenClassification
* add forward for GPT2ForSequenceClassification
* fix test_shard_gpt2.py
* add GPT2DoubleHeadsModel & fix bugs
* add id checking in get_shared_params

branch: pull/4445/head
parent: 7e4de520e1
commit: a14d352088
@@ -1,11 +1,11 @@
import logging
from functools import partial
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -48,6 +48,10 @@ class GPT2Policy(Policy):
                    suffix="wte",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
                SubModuleReplacementDescription(
                    suffix="drop",
                    target_module=col_nn.DropoutForParallelInput,
                ),
            ])
        policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
            "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -120,6 +124,45 @@ class GPT2Policy(Policy):
    def postprocess(self):
        return self.model

    def get_held_layers(self) -> List[nn.Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == 'GPT2Model':
            module = self.model
        else:
            module = self.model.transformer
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.wte)
            held_layers.append(module.wpe)
            held_layers.append(module.drop)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.h[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.ln_f)
        return held_layers

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under a pipeline parallel setting, replace the original huggingface forward method
        with the customized forward method, and add this change to the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == 'GPT2Model':
                module = self.model
            else:
                module = self.model.transformer

            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
            self.append_or_create_method_replacement(description=method_replacement,
                                                     policy=policy,
                                                     target_key=model_cls)


# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
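For orientation, this is roughly how the `distribute_layers` / `get_stage_index` pair used above is expected to partition the `module.h` blocks across pipeline stages. The helper bodies below are a minimal sketch under the assumption of a near-even split; the names ending in `_sketch` are illustrative, not the actual `Policy` implementations.

from typing import List, Tuple

def distribute_layers_sketch(num_layers: int, num_stages: int) -> List[int]:
    """Assumed behaviour: split the blocks as evenly as possible across stages."""
    base, remainder = divmod(num_layers, num_stages)
    # Give one extra block to the last `remainder` stages (an assumption).
    return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]

def get_stage_index_sketch(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Return the [start, end) block indices held by `stage`."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# Example: GPT-2 small has 12 transformer blocks; with 4 pipeline stages
# each stage would hold 3 consecutive entries of `module.h`.
layers = distribute_layers_sketch(12, 4)    # [3, 3, 3, 3]
print(get_stage_index_sketch(layers, 1))    # (3, 6)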
@@ -131,40 +174,16 @@ class GPT2ModelPolicy(GPT2Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model

        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # set None as default
            stage_manager = self.pipeline_stage_manager
            layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {
                'forward':
                    partial(GPT2PipelineForwards.gpt2_model_forward,
                            stage_manager=stage_manager,
                            stage_index=stage_index)
            }
            self.append_or_create_method_replacement(description=method_replacement,
                                                     policy=policy,
                                                     target_key=GPT2Model)
        self.set_pipeline_forward(model_cls=GPT2Model,
                                  new_forward=GPT2PipelineForwards.gpt2_model_forward,
                                  policy=policy)
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        module = self.model
        stage_manager = self.pipeline_stage_manager
        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.wte)
            held_layers.append(module.wpe)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.h[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.ln_f)
        return held_layers
    def get_held_layers(self) -> List[nn.Module]:
        return super().get_held_layers()

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        # TODO: check whether there is shared param in gpt2model
        """No shared params in gpt2 model."""
        """No shared params in GPT2Model."""
        return []
@@ -188,10 +207,31 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
            ])
        }
        module_policy.update(addon_module)

        self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
                                  new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
                                  policy=module_policy)
        return module_policy

    def get_held_layers(self) -> List[nn.Module]:
        held_layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        '''The weights of wte and lm_head are shared.'''
        module = self.model
        stage_manager = self.pipeline_stage_manager
        if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
            first_stage, last_stage = 0, stage_manager.num_stages - 1
            return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
        else:
            return []

    def postprocess(self):
        if self.shard_config.enable_tensor_parallelism:
        if self.shard_config.enable_tensor_parallelism \
                and self.pipeline_stage_manager is None:
            binding_map = {"transformer.wte.weight": "lm_head.weight"}
            for k, v in binding_map.items():
                param = getattr_(self.model, k)
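get_shared_params above relies on GPT2LMHeadModel tying `transformer.wte.weight` and `lm_head.weight` to the same Parameter object, which makes the `id()` comparison a cheap tie check. A toy reproduction of that check (layer sizes are illustrative):

import torch.nn as nn

emb = nn.Embedding(50257, 768)
lm_head = nn.Linear(768, 50257, bias=False)
lm_head.weight = emb.weight          # weight tying, as GPT2LMHeadModel does

# Same Parameter object -> the two "weights" are one shared parameter.
assert id(emb.weight) == id(lm_head.weight)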
@@ -199,7 +239,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
        return self.model


# GPT22DoubleHeadsModel
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):

    def __init__(self) -> None:
@@ -219,10 +259,38 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
            ])
        }
        module_policy.update(addon_module)

        self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
                                  new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
                                  policy=module_policy)

        return module_policy

    def get_held_layers(self) -> List[nn.Module]:
        held_layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            multiple_choice_head = self.model.multiple_choice_head
            held_layers.append(self.model.lm_head)
            held_layers.append(multiple_choice_head.summary)
            held_layers.append(multiple_choice_head.activation)
            held_layers.append(multiple_choice_head.first_dropout)
            held_layers.append(multiple_choice_head.last_dropout)

        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        '''The weights of wte and lm_head are shared.'''
        module = self.model
        stage_manager = self.pipeline_stage_manager
        if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
            first_stage, last_stage = 0, stage_manager.num_stages - 1
            return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
        else:
            return []

    def postprocess(self):
        if self.shard_config.enable_tensor_parallelism:
        if self.shard_config.enable_tensor_parallelism \
                and self.pipeline_stage_manager is None:
            binding_map = {"transformer.wte.weight": "lm_head.weight"}
            for k, v in binding_map.items():
                param = getattr_(self.model, k)
@@ -236,6 +304,36 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                GPT2ForTokenClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
                    ])
            }
            module_policy.update(addon_module)

        self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
                                  new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
                                  policy=module_policy)
        return module_policy

    def get_held_layers(self) -> List[nn.Module]:
        held_layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            held_layers.append(self.model.dropout)
            held_layers.append(self.model.classifier)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in GPT2ForTokenClassification."""
        return []


# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
@@ -243,6 +341,25 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification

        module_policy = super().module_policy()
        self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
                                  new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
                                  policy=module_policy)
        return module_policy

    def get_held_layers(self) -> List[nn.Module]:
        held_layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            held_layers.append(self.model.score)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in GPT2ForSequenceClassification."""
        return []


class GPT2PipelineForwards:
    '''
@@ -299,8 +416,7 @@ class GPT2PipelineForwards:
            if token_type_ids is not None:
                token_type_ids = token_type_ids.view(-1, seq_length)
        else:
            if hidden_states is None:
                raise ValueError("hidden_states shouln't be None for stages other than the first stage.")
            assert hidden_states is not None
            input_shape = hidden_states.size()[:-1]
            batch_size, seq_length = input_shape[0], input_shape[1]
            device = hidden_states.device
@@ -462,3 +578,356 @@ class GPT2PipelineForwards:
        else:
            # always return dict for intermediate stage
            return {'hidden_states': hidden_states}

    @staticmethod
    def gpt2_lmhead_model_forward(
            self: 'GPT2LMHeadModel',
            input_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            stage_manager: Optional[PipelineStageManager] = None,
            hidden_states: Optional[torch.FloatTensor] = None,
            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.

        This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
        Please refer to the original code of transformers for more details.
        """

        from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
                                                          input_ids,
                                                          past_key_values=past_key_values,
                                                          attention_mask=attention_mask,
                                                          token_type_ids=token_type_ids,
                                                          position_ids=position_ids,
                                                          head_mask=head_mask,
                                                          inputs_embeds=inputs_embeds,
                                                          encoder_hidden_states=encoder_hidden_states,
                                                          encoder_attention_mask=encoder_attention_mask,
                                                          use_cache=use_cache,
                                                          output_attentions=output_attentions,
                                                          output_hidden_states=output_hidden_states,
                                                          return_dict=return_dict,
                                                          stage_manager=stage_manager,
                                                          hidden_states=hidden_states,
                                                          stage_index=stage_index)

        # If not at the last stage, return hidden_states as in GPT2Model
        if not stage_manager.is_last_stage():
            return {'hidden_states': outputs['hidden_states']}

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
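The loss block above uses the standard causal-LM label shift: logits at position t are scored against the token at position t+1. The snippet below is a self-contained illustration of that shift; the vocabulary size and tensor shapes are arbitrary toy values.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
lm_logits = torch.randn(1, 5, vocab_size)      # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (1, 5))  # labels == input_ids for LM training

shift_logits = lm_logits[..., :-1, :].contiguous()   # drop the last position
shift_labels = labels[..., 1:].contiguous()          # drop the first token
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())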
    @staticmethod
    def gpt2_double_heads_model_forward(
            self: 'GPT2DoubleHeadsModel',
            input_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            mc_token_ids: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            mc_labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            stage_manager: Optional[PipelineStageManager] = None,
            hidden_states: Optional[torch.FloatTensor] = None,
            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']:
        r"""
        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
            1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
        mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)

        This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward.
        Please refer to the original code of transformers for more details.
        """
        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
                                                          input_ids,
                                                          past_key_values=past_key_values,
                                                          attention_mask=attention_mask,
                                                          token_type_ids=token_type_ids,
                                                          position_ids=position_ids,
                                                          head_mask=head_mask,
                                                          inputs_embeds=inputs_embeds,
                                                          use_cache=use_cache,
                                                          output_attentions=output_attentions,
                                                          output_hidden_states=output_hidden_states,
                                                          return_dict=return_dict,
                                                          stage_manager=stage_manager,
                                                          hidden_states=hidden_states,
                                                          stage_index=stage_index)

        # If not at the last stage, return hidden_states as in GPT2Model
        if not stage_manager.is_last_stage():
            return {'hidden_states': outputs['hidden_states']}

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        mc_loss = None
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
        lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits, mc_logits) + outputs[1:]
            if mc_loss is not None:
                output = (mc_loss,) + output
            return ((lm_loss,) + output) if lm_loss is not None else output

        return GPT2DoubleHeadsModelOutput(
            loss=lm_loss,
            mc_loss=mc_loss,
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
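The multiple-choice head consumes `mc_token_ids` to pick one classification-token hidden state per choice before scoring it. The toy sketch below shows one way such a gather can be expressed; it illustrates the indexing idea only, not the actual `SequenceSummary` implementation, and all shapes are made up.

import torch

batch, num_choices, seq_len, hidden = 2, 4, 16, 8
hidden_states = torch.randn(batch, num_choices, seq_len, hidden)
mc_token_ids = torch.full((batch, num_choices), seq_len - 1)   # index of the cls token per choice

# Gather the hidden state at mc_token_ids along the sequence dimension.
index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, hidden)
cls_states = hidden_states.gather(2, index).squeeze(2)         # (batch, num_choices, hidden)
print(cls_states.shape)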
    @staticmethod
    def gpt2_for_token_classification_forward(
            self: 'GPT2ForTokenClassification',
            input_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            stage_manager: Optional[PipelineStageManager] = None,
            hidden_states: Optional[torch.FloatTensor] = None,
            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward.
        Please refer to the original code of transformers for more details.
        """

        from transformers.modeling_outputs import TokenClassifierOutput

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
                                                          input_ids,
                                                          past_key_values=past_key_values,
                                                          attention_mask=attention_mask,
                                                          token_type_ids=token_type_ids,
                                                          position_ids=position_ids,
                                                          head_mask=head_mask,
                                                          inputs_embeds=inputs_embeds,
                                                          use_cache=use_cache,
                                                          output_attentions=output_attentions,
                                                          output_hidden_states=output_hidden_states,
                                                          return_dict=return_dict,
                                                          stage_manager=stage_manager,
                                                          hidden_states=hidden_states,
                                                          stage_index=stage_index)

        # If not at the last stage, return hidden_states as in GPT2Model
        if not stage_manager.is_last_stage():
            return {'hidden_states': outputs['hidden_states']}

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    @staticmethod
    def gpt2_for_sequence_classification_forward(
            self: 'GPT2ForSequenceClassification',
            input_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            stage_manager: Optional[PipelineStageManager] = None,
            hidden_states: Optional[torch.FloatTensor] = None,
            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
        Please refer to the original code of transformers for more details.
        """
        from transformers.modeling_outputs import SequenceClassifierOutputWithPast

        if input_ids is not None:
            batch_size, _ = input_ids.shape[:2]
        else:
            batch_size, _ = hidden_states.shape[:2]
        assert (self.config.pad_token_id is not None
                or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
                                                          input_ids,
                                                          past_key_values=past_key_values,
                                                          attention_mask=attention_mask,
                                                          token_type_ids=token_type_ids,
                                                          position_ids=position_ids,
                                                          head_mask=head_mask,
                                                          inputs_embeds=inputs_embeds,
                                                          use_cache=use_cache,
                                                          output_attentions=output_attentions,
                                                          output_hidden_states=output_hidden_states,
                                                          return_dict=return_dict,
                                                          stage_manager=stage_manager,
                                                          hidden_states=hidden_states,
                                                          stage_index=stage_index)

        # If not at the last stage, return hidden_states as in GPT2Model
        if not stage_manager.is_last_stage():
            return {'hidden_states': outputs['hidden_states']}

        hidden_states = outputs[0]
        logits = self.score(hidden_states)

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                sequence_lengths = -1
                logging.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`")

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
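The pooling step above selects, for each example, the logits at the last non-padding position. A small standalone illustration of that indexing (the pad token id and all shapes are made-up values for the example):

import torch

pad_token_id = 0
batch_size, seq_len, num_labels = 3, 6, 2
input_ids = torch.tensor([[5, 8, 2, 0, 0, 0],
                          [7, 1, 4, 9, 3, 0],
                          [6, 6, 6, 6, 6, 6]])
logits = torch.randn(batch_size, seq_len, num_labels)

# Index of the last non-pad token in each row, as in the forward above.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 4, 5])
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)   # torch.Size([3, 2])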
@@ -5,15 +5,9 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
from tests.test_shardformer.test_model._utils import build_pipeline_model


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    pass


@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_gpt2
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name != "transformers_gpt":
            continue
    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():

        inputs = data_gen_fn()
        inputs = {k: v.cuda() for k, v in inputs.items()}
        input_ids, _ = inputs['input_ids'], inputs['attention_mask']
        batch_size, seq_len = input_ids.shape
        hidden_size = 768
        hidden_state_shape = (batch_size, seq_len, hidden_size)

        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                        enable_tensor_parallelism, use_lazy_init)
        org_model.train()
        org_output = org_model(**inputs)
        hidden_state_shape = org_output['last_hidden_state'].shape

        if stage_manager.is_first_stage():
            output = sharded_model(**inputs)
            assert output['hidden_states'].shape == hidden_state_shape
        else:
            attention_mask = inputs['attention_mask']
        if not stage_manager.is_first_stage():
            # change inputs if not the first stage
            hidden_states = torch.zeros(*hidden_state_shape).cuda()
            output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask)
            if stage_manager.is_last_stage():
                assert output['last_hidden_state'].shape == hidden_state_shape
            else:
                assert output['hidden_states'].shape == hidden_state_shape
            inputs['input_ids'] = None
            inputs['hidden_states'] = hidden_states

        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                enable_tensor_parallelism, use_lazy_init)
        sharded_model.train()
        output = sharded_model(**inputs)
        if stage_manager.is_last_stage():
            if name != 'transformers_gpt':
                assert output.loss is not None
        else:
            assert output['hidden_states'].shape == hidden_state_shape

    torch.cuda.empty_cache()