[pipeline] add pipeline support for all T5 models (#4310)

* complete policy for T5Model & T5ForConditionalGeneration

* modify function signature in forwards

* add forward for T5model

* add forward for T5ForConditionalGeneration

* fix a bug

* fix hidden_states transporting in decoder

* fix the passing of encoder_outputs
Baizhou Zhang 2023-07-25 14:45:33 +08:00 committed by Hongxin Liu
parent d0807122e2
commit 083d7da33d
3 changed files with 388 additions and 19 deletions
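A note on how the new forwards are organized: every pipeline forward below is keyed off decoder_starting_stage. Stages before that index run the encoder stack, stages at or after it run the decoder stack, and the last encoder stage hands its output on as encoder_outputs for the decoder stages. The boundary itself comes from T5BasePolicy.distribute_t5_layers, which is used in the policy diff further down. The snippet below is only a hypothetical sketch of such a split, assuming a proportional division of stages between the two stacks; it is not the actual distribute_t5_layers implementation.

from typing import List, Tuple


def split_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
                    num_stages: int) -> Tuple[List[int], int]:
    """Hypothetical illustration of splitting encoder/decoder blocks across pipeline stages."""
    total = num_encoder_layers + num_decoder_layers
    # Assumption: give the encoder a share of stages proportional to its layer count,
    # while reserving at least one stage for each stack.
    encoder_stages = max(1, min(num_stages - 1,
                                round(num_stages * num_encoder_layers / total)))
    decoder_starting_stage = encoder_stages  # first stage that holds decoder blocks

    def even_split(n_layers: int, n_stages: int) -> List[int]:
        # Spread layers evenly; earlier stages absorb the remainder.
        base, rem = divmod(n_layers, n_stages)
        return [base + (1 if i < rem else 0) for i in range(n_stages)]

    layers_per_stage = (even_split(num_encoder_layers, encoder_stages) +
                        even_split(num_decoder_layers, num_stages - encoder_stages))
    return layers_per_stage, decoder_starting_stage


# Example: 12 encoder blocks and 12 decoder blocks on 4 stages
# -> layers_per_stage == [6, 6, 6, 6], decoder_starting_stage == 2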


@@ -1,11 +1,15 @@
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import warnings
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
from transformers.utils import logging
@@ -198,14 +202,13 @@ class T5PipelineForwards:
if use_cache is False or use_cache is None:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2]
# print(stage, len(layer_outputs), present_key_value_state.shape)
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
# (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
if in_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
@@ -238,6 +241,313 @@ class T5PipelineForwards:
'encoder_decoder_position_bias': encoder_decoder_position_bias
}
@staticmethod
def t5_model_forward(
self: T5Model,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
# This function is adapted from transformers.models.t5.modeling_t5.T5Model.forward.
# Please refer to the original transformers implementation for more details.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logger = logging.get_logger(__name__)
# TODO: recording of key/value tensors is currently left as () or None; this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
in_decoder = stage_manager.stage >= decoder_starting_stage
# Stage is in the encoder: directly return the output of t5_stack_forward
if not in_decoder:
encoder_outputs = T5PipelineForwards.t5_stack_forward(
self.encoder,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
if stage_manager.stage == decoder_starting_stage - 1:
# last stage of encoder
return {'encoder_outputs': encoder_outputs}
else:
return encoder_outputs
at_last_decoder_stage = stage_manager.is_last_stage()
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
if encoder_outputs is None:
raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
encoder_hidden_states = encoder_outputs[0]
if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# Stage is in the decoder; we assume the outputs of the last encoder stage are passed in.
if not at_first_decoder_stage and hidden_states is None:
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
# Decode
decoder_outputs = T5PipelineForwards.t5_stack_forward(
self.decoder,
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
# Directly return outputs of overloaded T5Stack forward if not at last stage.
if not at_last_decoder_stage:
decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
return decoder_outputs
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@staticmethod
def t5_for_conditional_generation_forward(
self: T5ForConditionalGeneration,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
# This function is adapted from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward.
# Please refer to the original transformers implementation for more details.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logger = logging.get_logger(__name__)
# TODO: recording of key/value tensors is currently left as () or None; this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
in_decoder = stage_manager.stage >= decoder_starting_stage
# Stage is in the encoder: directly return the output of t5_stack_forward
if not in_decoder:
encoder_outputs = T5PipelineForwards.t5_stack_forward(
self.encoder,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
if stage_manager.stage == decoder_starting_stage - 1:
# last stage of encoder
return {'encoder_outputs': encoder_outputs}
else:
return encoder_outputs
at_last_decoder_stage = stage_manager.is_last_stage()
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
if encoder_outputs is None:
raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
encoder_hidden_states = encoder_outputs[0]
if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# Stage is in the decoder; we assume the outputs of the last encoder stage are passed in.
if not at_first_decoder_stage and hidden_states is None:
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
# Decode
decoder_outputs = T5PipelineForwards.t5_stack_forward(
self.decoder,
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
# Directly return outputs of overloaded T5Stack forward if not at last stage.
if not at_last_decoder_stage:
decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
return decoder_outputs
sequence_output = decoder_outputs[0]
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@staticmethod
def t5_encoder_model_forward(
self: T5EncoderModel,

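To make the data flow of the forwards above concrete: a stage that is not the first stage of its stack does not receive input_ids; it receives the tensors the previous stage returned, and decoder stages additionally need encoder_outputs. Below is a minimal sketch of the inputs for a decoder stage that is not the first decoder stage, mirroring the test diff further below; the tensor shapes are illustrative assumptions, not values taken from the test.

import torch

batch_size, seq_len, d_model, num_heads = 2, 8, 512, 8          # illustrative sizes only
hidden_state_shape = (batch_size, seq_len, d_model)
position_bias_shape = (batch_size, num_heads, seq_len, seq_len)

inputs = {
    'hidden_states': torch.zeros(hidden_state_shape),                   # output of the previous stage
    'position_bias': torch.zeros(position_bias_shape),                  # shared self-attention bias
    'encoder_decoder_position_bias': torch.zeros(position_bias_shape),  # shared cross-attention bias
    'encoder_outputs': (torch.zeros(hidden_state_shape),),              # last hidden state of the encoder
}
# output = sharded_model(**inputs)
# An intermediate decoder stage returns a dict that carries 'hidden_states', 'position_bias',
# 'encoder_decoder_position_bias' and 'encoder_outputs' on to the next stage; only the last
# decoder stage assembles a Seq2SeqModelOutput / Seq2SeqLMOutput.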

@@ -293,21 +293,42 @@ class T5BasePolicy(Policy):
class T5ModelPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5Model
base_policy = super().module_policy()
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=base_policy,
policy=policy,
target_key=T5Model)
return base_policy
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
len(module.decoder.block),
stage_manager.num_stages)
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
for k, v in binding_map.items():
src = getattr_(self.model, k)
@@ -318,6 +339,9 @@ class T5ModelPolicy(T5BasePolicy):
class T5ForConditionalGenerationPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5ForConditionalGeneration
@@ -335,8 +359,38 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
],
policy=policy,
target_key=T5ForConditionalGeneration)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
len(module.decoder.block),
stage_manager.num_stages)
shared_params = []
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
shared_params.append({
0: module.shared.weight,
decoder_starting_stage: module.decoder.embed_tokens.weight
})
if id(module.lm_head.weight) == id(module.shared.weight):
shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
return shared_params
return []
def postprocess(self):
super().postprocess()
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
@@ -382,7 +436,7 @@ class T5EncoderPolicy(T5BasePolicy):
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
for k, v in binding_map.items():
src = getattr_(self.model, k)

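The get_shared_params overrides in the policy diff above describe, for each tied weight, which pipeline stage holds which copy, so the runtime can keep the copies synchronized across stages. Below is a small self-contained sketch of the structure returned for T5ForConditionalGeneration, assuming 4 pipeline stages with decoder_starting_stage == 2; the tiny config values are arbitrary and only serve to build tied weights.

from transformers import T5Config, T5ForConditionalGeneration

# Tiny, arbitrary config just to create tied weights; assumes 4 pipeline stages
# with decoder_starting_stage == 2, so the last stage index is 3.
config = T5Config(vocab_size=100, d_model=64, d_ff=128, num_layers=4,
                  num_decoder_layers=4, num_heads=4, tie_word_embeddings=True)
module = T5ForConditionalGeneration(config)
decoder_starting_stage, last_stage = 2, 3

# Each entry maps {pipeline stage index: the parameter held on that stage};
# an entry is emitted only when the weight is actually tied to `shared`.
shared_params = []
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
    shared_params.append({0: module.shared.weight,
                          decoder_starting_stage: module.decoder.embed_tokens.weight})
if id(module.lm_head.weight) == id(module.shared.weight):
    shared_params.append({0: module.shared.weight, last_stage: module.lm_head.weight})
# -> two entries: shared/decoder-embedding held on stages (0, 2), shared/lm_head on stages (0, 3).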

@@ -28,8 +28,6 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
if name != 'transformers_t5_encoder_model':
continue
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
@@ -52,6 +50,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
stage = stage_manager.stage
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
in_decoder = stage >= decoder_starting_stage
if not at_first_stage:
# change inputs if not the first stage
@@ -62,19 +61,25 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
inputs['hidden_states'] = hidden_states
inputs['position_bias'] = position_bias
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
if in_decoder:
encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
inputs['encoder_outputs'] = (encoder_output_states,)
sharded_model.train()
output = sharded_model(**inputs)
if at_last_stage:
if name != 'transformers_t5_for_conditional_generation':
assert output[0].shape == hidden_state_shape
else:
if name == 'transformers_t5_for_conditional_generation' and in_decoder:
assert output.loss is not None
else:
if name != 'transformers_t5_encoder_model' and not in_decoder:
output = output['encoder_outputs']
assert output[0].shape == hidden_state_shape
else:
assert output['hidden_states'].shape == hidden_state_shape
# position_bias information should be passed in T5
assert 'position_bias' in output
assert 'encoder_decoder_position_bias' in output
assert output['position_bias'].shape == position_bias_shape
if in_decoder:
assert output['encoder_decoder_position_bias'].shape == position_bias_shape
torch.cuda.empty_cache()
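For reference, the stage bookkeeping used in this test works out as follows on a hypothetical 4-stage run with decoder_starting_stage == 2 (values chosen only for illustration):

# Worked example of the stage flags computed in the test above, assuming 4 stages
# and decoder_starting_stage == 2 (illustrative values only).
num_stages, decoder_starting_stage = 4, 2
for stage in range(num_stages):
    at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
    at_last_stage = (stage == decoder_starting_stage - 1) or (stage == num_stages - 1)
    in_decoder = stage >= decoder_starting_stage
    print(f"stage {stage}: first={at_first_stage} last={at_last_stage} decoder={in_decoder}")
# stage 0: first encoder stage (takes input_ids)
# stage 1: last encoder stage (returns {'encoder_outputs': ...})
# stage 2: first decoder stage (takes decoder inputs plus encoder_outputs)
# stage 3: last decoder stage (returns the final Seq2SeqLMOutput / loss)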