mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] add pipeline forward for variants of gpt2 (#4238)
* add forward for GPT2LMHeadModel
* add test for gpt_lm
* arranging get_held_layers method
* arrange forward replacement
* add forward for GPT2ForTokenClassification
* add forward for GPT2ForSequenceClassification
* fix test_shard_gpt2.py
* add GPT2DoubleHeadsModel & fix bugs
* add id checking in get_shared_params

pull/4445/head
parent 7e4de520e1
commit a14d352088
@@ -1,11 +1,11 @@
 import logging
 from functools import partial
 from types import MethodType
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from torch import Tensor
-from torch.nn import Module
+from torch import Tensor, nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 import colossalai.shardformer.layer as col_nn
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -48,6 +48,10 @@ class GPT2Policy(Policy):
                     suffix="wte",
                     target_module=col_nn.VocabParallelEmbedding1D,
                 ),
+                SubModuleReplacementDescription(
+                    suffix="drop",
+                    target_module=col_nn.DropoutForParallelInput,
+                ),
             ])
         policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
             "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -120,6 +124,45 @@ class GPT2Policy(Policy):
     def postprocess(self):
         return self.model

+    def get_held_layers(self) -> List[nn.Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        if self.model.__class__.__name__ == 'GPT2Model':
+            module = self.model
+        else:
+            module = self.model.transformer
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.wte)
+            held_layers.append(module.wpe)
+            held_layers.append(module.drop)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.h[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.ln_f)
+        return held_layers
+
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """If under the pipeline parallel setting, replace the original forward method of the
+        huggingface model with the customized forward method, and register this change in the policy."""
+        if self.pipeline_stage_manager:
+            stage_manager = self.pipeline_stage_manager
+            if self.model.__class__.__name__ == 'GPT2Model':
+                module = self.model
+            else:
+                module = self.model.transformer
+
+            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=model_cls)
+

 # GPT2Model
 class GPT2ModelPolicy(GPT2Policy):
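For illustration only: the block split in get_held_layers relies on the base Policy helpers distribute_layers and get_stage_index, which are defined outside this diff. A minimal sketch of a near-even split, assuming (hypothetically) that extra blocks go to the earliest stages; the real helpers may handle the remainder differently:

from typing import List, Tuple

def distribute_layers_sketch(num_layers: int, num_stages: int) -> List[int]:
    # hypothetical near-even split; the real Policy.distribute_layers may differ
    quotient, remainder = divmod(num_layers, num_stages)
    return [quotient + (1 if i < remainder else 0) for i in range(num_stages)]

def get_stage_index_sketch(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # start/end indices of the GPT2Block slice owned by this stage
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# e.g. 12 GPT-2 blocks over 4 stages -> [3, 3, 3, 3]; stage 1 would hold module.h[3:6]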
@@ -131,40 +174,16 @@ class GPT2ModelPolicy(GPT2Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Model

         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            # set None as default
-            stage_manager = self.pipeline_stage_manager
-            layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            method_replacement = {
-                'forward':
-                    partial(GPT2PipelineForwards.gpt2_model_forward,
-                            stage_manager=stage_manager,
-                            stage_index=stage_index)
-            }
-            self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=GPT2Model)
+        self.set_pipeline_forward(model_cls=GPT2Model,
+                                  new_forward=GPT2PipelineForwards.gpt2_model_forward,
+                                  policy=policy)
         return policy

-    def get_held_layers(self) -> List[Module]:
-        """Get pipeline layers for current stage."""
-        module = self.model
-        stage_manager = self.pipeline_stage_manager
-        held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        if stage_manager.is_first_stage():
-            held_layers.append(module.wte)
-            held_layers.append(module.wpe)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
-        held_layers.extend(module.h[start_idx:end_idx])
-        if stage_manager.is_last_stage():
-            held_layers.append(module.ln_f)
-        return held_layers
+    def get_held_layers(self) -> List[nn.Module]:
+        return super().get_held_layers()

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        # TODO: check whether there is shared param in gpt2model
-        """No shared params in gpt2 model."""
+        """No shared params in GPT2Model."""
         return []
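For illustration only: the refactor above moves the manual method_replacement construction into GPT2Policy.set_pipeline_forward. The underlying idea is to bind stage information onto a standalone forward function with functools.partial and attach it as the module's bound forward. A self-contained sketch, assuming the sharder ultimately binds the replacement roughly like this (the toy module and the `scale` argument are purely illustrative, not from this patch):

from functools import partial
from types import MethodType

import torch
from torch import nn

def custom_forward(self, x, scale=1.0):
    # stand-in for a stage-aware pipeline forward; `scale` plays the role of stage_index
    return self.proj(x) * scale

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

toy = Toy()
# bind the extra argument up front, then attach the partial as a bound method
toy.forward = MethodType(partial(custom_forward, scale=2.0), toy)
print(toy.forward(torch.ones(1, 4)).shape)    # torch.Size([1, 4])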
@@ -188,10 +207,31 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                     ])
             }
             module_policy.update(addon_module)
+
+        self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
+                                  new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
+                                  policy=module_policy)
         return module_policy

+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        '''The weights of wte and lm_head are shared.'''
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+            first_stage, last_stage = 0, stage_manager.num_stages - 1
+            return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+        else:
+            return []
+
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism \
+                and self.pipeline_stage_manager is None:
             binding_map = {"transformer.wte.weight": "lm_head.weight"}
             for k, v in binding_map.items():
                 param = getattr_(self.model, k)
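For illustration only: the id() comparison in get_shared_params works because GPT2LMHeadModel ties lm_head.weight to the input embedding wte.weight by default, so both attribute paths reference a single tensor. A quick check (requires the transformers package; the checkpoint name is only an example):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
# with the default tie_word_embeddings config, both names point at the same storage
print(id(model.transformer.wte.weight) == id(model.lm_head.weight))    # True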
@@ -199,7 +239,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         return self.model


-# GPT22DoubleHeadsModel
+# GPT2DoubleHeadsModel
 class GPT2DoubleHeadsModelPolicy(GPT2Policy):

     def __init__(self) -> None:
@@ -219,10 +259,38 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
                     ])
             }
             module_policy.update(addon_module)
+
+        self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
+                                  new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
+                                  policy=module_policy)
+
         return module_policy

+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            multiple_choice_head = self.model.multiple_choice_head
+            held_layers.append(self.model.lm_head)
+            held_layers.append(multiple_choice_head.summary)
+            held_layers.append(multiple_choice_head.activation)
+            held_layers.append(multiple_choice_head.first_dropout)
+            held_layers.append(multiple_choice_head.last_dropout)
+
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        '''The weights of wte and lm_head are shared.'''
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+            first_stage, last_stage = 0, stage_manager.num_stages - 1
+            return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+        else:
+            return []
+
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism \
+                and self.pipeline_stage_manager is None:
             binding_map = {"transformer.wte.weight": "lm_head.weight"}
             for k, v in binding_map.items():
                 param = getattr_(self.model, k)
@@ -236,6 +304,36 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
     def __init__(self) -> None:
         super().__init__()

+    def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
+
+        module_policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            addon_module = {
+                GPT2ForTokenClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
+                    ])
+            }
+            module_policy.update(addon_module)
+
+        self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
+                                  new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
+                                  policy=module_policy)
+        return module_policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.dropout)
+            held_layers.append(self.model.classifier)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in GPT2ForTokenClassification."""
+        return []
+

 # GPT2ForSequenceClassification
 class GPT2ForSequenceClassificationPolicy(GPT2Policy):
@@ -243,6 +341,25 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
     def __init__(self) -> None:
         super().__init__()

+    def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
+
+        module_policy = super().module_policy()
+        self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
+                                  new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
+                                  policy=module_policy)
+        return module_policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in GPT2ForSequenceClassification."""
+        return []
+

 class GPT2PipelineForwards:
     '''
@@ -299,8 +416,7 @@ class GPT2PipelineForwards:
             if token_type_ids is not None:
                 token_type_ids = token_type_ids.view(-1, seq_length)
         else:
-            if hidden_states is None:
-                raise ValueError("hidden_states shouln't be None for stages other than the first stage.")
+            assert hidden_states is not None
             input_shape = hidden_states.size()[:-1]
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
@@ -462,3 +578,356 @@ class GPT2PipelineForwards:
         else:
             # always return dict for intermediate stage
             return {'hidden_states': hidden_states}
+
+    @staticmethod
+    def gpt2_lmhead_model_forward(
+            self: 'GPT2LMHeadModel',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            encoder_hidden_states: Optional[torch.Tensor] = None,
+            encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+
+        This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
+        Please refer to original code of transformers for more details.
+        """
+
+        from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          encoder_hidden_states=encoder_hidden_states,
+                                                          encoder_attention_mask=encoder_attention_mask,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
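For illustration only: the label shifting above is the standard causal-LM loss, where position t predicts token t+1, so logits drop the last step and labels drop the first. A small self-contained illustration with dummy tensors (shapes are arbitrary):

import torch
from torch.nn import CrossEntropyLoss

lm_logits = torch.randn(2, 5, 50257)                   # (batch, seq_len, vocab_size)
labels = torch.randint(0, 50257, (2, 5))

shift_logits = lm_logits[..., :-1, :].contiguous()     # predictions at positions 0..3
shift_labels = labels[..., 1:].contiguous()            # targets are the next tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.item())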
+    @staticmethod
+    def gpt2_double_heads_model_forward(
+            self: 'GPT2DoubleHeadsModel',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            mc_token_ids: Optional[torch.LongTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            mc_labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']:
+        r"""
+        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
+            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+            1]`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
+            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
+        mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+            where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
+
+        This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward.
+        Please refer to original code of transformers for more details.
+        """
+        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
+
+        mc_loss = None
+        if mc_labels is not None:
+            loss_fct = CrossEntropyLoss()
+            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+        lm_loss = None
+        if labels is not None:
+            labels = labels.to(lm_logits.device)
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits, mc_logits) + outputs[1:]
+            if mc_loss is not None:
+                output = (mc_loss,) + output
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return GPT2DoubleHeadsModelOutput(
+            loss=lm_loss,
+            mc_loss=mc_loss,
+            logits=lm_logits,
+            mc_logits=mc_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    @staticmethod
+    def gpt2_for_token_classification_forward(
+            self: 'GPT2ForTokenClassification',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward.
+        # Please refer to original code of transformers for more details.
+        """
+
+        from transformers.modeling_outputs import TokenClassifierOutput
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    @staticmethod
+    def gpt2_for_sequence_classification_forward(
+            self: 'GPT2ForSequenceClassification',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
+        # Please refer to original code of transformers for more details.
+        """
+        from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+
+        if input_ids is not None:
+            batch_size, _ = input_ids.shape[:2]
+        else:
+            batch_size, _ = hidden_states.shape[:2]
+        assert (self.config.pad_token_id is not None
+                or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+                                                          input_ids,
+                                                          past_key_values=past_key_values,
+                                                          attention_mask=attention_mask,
+                                                          token_type_ids=token_type_ids,
+                                                          position_ids=position_ids,
+                                                          head_mask=head_mask,
+                                                          inputs_embeds=inputs_embeds,
+                                                          use_cache=use_cache,
+                                                          output_attentions=output_attentions,
+                                                          output_hidden_states=output_hidden_states,
+                                                          return_dict=return_dict,
+                                                          stage_manager=stage_manager,
+                                                          hidden_states=hidden_states,
+                                                          stage_index=stage_index)
+
+        # If not at the last stage, return hidden_states as in GPT2Model
+        if not stage_manager.is_last_stage():
+            return {'hidden_states': outputs['hidden_states']}
+
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+                logging.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
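For illustration only: for sequence classification, the forward above pools the logit of the last non-padding token in each sequence. A dummy-tensor illustration of that indexing (the pad id and shapes are made up):

import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 22, 33, pad_token_id],
                          [44, 55, pad_token_id, pad_token_id]])
logits = torch.randn(2, 4, 3)                       # (batch, seq_len, num_labels)

# index of the last non-pad token in each sequence: here 2 and 1
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
pooled_logits = logits[torch.arange(input_ids.size(0)), sequence_lengths]
print(pooled_logits.shape)                          # torch.Size([2, 3])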
@@ -5,15 +5,9 @@ import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
+from tests.test_shardformer.test_model._utils import build_pipeline_model


 def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     pass


-@parameterize('enable_fused_normalization', [False])
 @parameterize('enable_tensor_parallelism', [False])
+@parameterize('enable_fused_normalization', [False])
 @parameterize('use_lazy_init', [False])
 #TODO: merge this into test_shard_gpt2
 def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name != "transformers_gpt":
-            continue
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():

         inputs = data_gen_fn()
         inputs = {k: v.cuda() for k, v in inputs.items()}
+        input_ids, _ = inputs['input_ids'], inputs['attention_mask']
+        batch_size, seq_len = input_ids.shape
+        hidden_size = 768
+        hidden_state_shape = (batch_size, seq_len, hidden_size)

-        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                        enable_tensor_parallelism, use_lazy_init)
-        org_model.train()
-        org_output = org_model(**inputs)
-        hidden_state_shape = org_output['last_hidden_state'].shape
-
-        if stage_manager.is_first_stage():
-            output = sharded_model(**inputs)
-            assert output['hidden_states'].shape == hidden_state_shape
-        else:
-            attention_mask = inputs['attention_mask']
+        if not stage_manager.is_first_stage():
+            # change inputs if not the first stage
             hidden_states = torch.zeros(*hidden_state_shape).cuda()
-            output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask)
-            if stage_manager.is_last_stage():
-                assert output['last_hidden_state'].shape == hidden_state_shape
-            else:
-                assert output['hidden_states'].shape == hidden_state_shape
+            inputs['input_ids'] = None
+            inputs['hidden_states'] = hidden_states
+
+        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                enable_tensor_parallelism, use_lazy_init)
+        sharded_model.train()
+        output = sharded_model(**inputs)
+        if stage_manager.is_last_stage():
+            if name != 'transformers_gpt':
+                assert output.loss is not None
+            else:
+                assert output['hidden_states'].shape == hidden_state_shape

     torch.cuda.empty_cache()
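For illustration only: the reworked test drives non-first pipeline stages directly, replacing input_ids with a zero hidden-state tensor whose hidden size (768) mirrors the hard-coded value in the test above. The stage input swap in isolation, with placeholder shapes:

import torch

batch_size, seq_len, hidden_size = 2, 16, 768       # 768 matches the value used in the test
inputs = {'input_ids': torch.randint(0, 50257, (batch_size, seq_len)),
          'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long)}

# a non-first stage consumes hidden states instead of token ids
inputs['input_ids'] = None
inputs['hidden_states'] = torch.zeros(batch_size, seq_len, hidden_size)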