mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] add pipeline support for all T5 models (#4310)
* complete policy for T5Model & T5ForConditionalGeneration * modify function signature in forwards * add forward for T5model * add forward for T5ForConditionalGeneration * fix a bug * fix hidden_states transporting in decoder * fix the passing of encoder_outputspull/4445/head
parent
d0807122e2
commit
083d7da33d
|
@ -1,11 +1,15 @@
|
|||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
|
||||
from transformers.utils import logging
|
||||
|
||||
|
@ -198,14 +202,13 @@ class T5PipelineForwards:
|
|||
if use_cache is False or use_cache is None:
|
||||
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
|
||||
hidden_states, present_key_value_state = layer_outputs[:2]
|
||||
# print(stage, len(layer_outputs), present_key_value_state.shape)
|
||||
|
||||
# We share the position biases between the layers - the first layer store them
|
||||
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
|
||||
# (cross-attention position bias), (cross-attention weights)
|
||||
position_bias = layer_outputs[2]
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if in_decoder and encoder_hidden_states is not None:
|
||||
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
|
||||
# append next layer key value states
|
||||
if use_cache:
|
||||
|
@ -238,6 +241,313 @@ class T5PipelineForwards:
|
|||
'encoder_decoder_position_bias': encoder_decoder_position_bias
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def t5_model_forward(
|
||||
self: T5Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
decoder_inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
decoder_starting_stage: Optional[int] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
||||
|
||||
# This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
|
||||
__HEAD_MASK_WARNING_MSG = """
|
||||
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
||||
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
||||
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
|
||||
num_heads)`.
|
||||
"""
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if past_key_values:
|
||||
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
past_key_values = None
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
use_cache = False
|
||||
|
||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||
if head_mask is not None and decoder_head_mask is None:
|
||||
if self.config.num_layers == self.config.num_decoder_layers:
|
||||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||
decoder_head_mask = head_mask
|
||||
|
||||
in_decoder = stage_manager.stage >= decoder_starting_stage
|
||||
|
||||
# Stage is in encoder, directly return the output of t5_stack_forward
|
||||
if not in_decoder:
|
||||
encoder_outputs = T5PipelineForwards.t5_stack_forward(
|
||||
self.encoder,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
position_bias=position_bias,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
# last stage of encoder
|
||||
return {'encoder_outputs': encoder_outputs}
|
||||
else:
|
||||
return encoder_outputs
|
||||
|
||||
at_last_decoder_stage = stage_manager.is_last_stage()
|
||||
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
||||
|
||||
if encoder_outputs is None:
|
||||
raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
||||
encoder_outputs = BaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||
)
|
||||
|
||||
# Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
|
||||
if not at_first_decoder_stage and hidden_states is None:
|
||||
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
||||
|
||||
# Decode
|
||||
decoder_outputs = T5PipelineForwards.t5_stack_forward(
|
||||
self.decoder,
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
position_bias=position_bias,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
|
||||
# Directly return outputs of overloaded T5Stack forward if not at last stage.
|
||||
if not at_last_decoder_stage:
|
||||
decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
|
||||
return decoder_outputs
|
||||
|
||||
if not return_dict:
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return Seq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def t5_for_conditional_generation_forward(
|
||||
self: T5ForConditionalGeneration,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
decoder_starting_stage: Optional[int] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
||||
|
||||
# This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
|
||||
__HEAD_MASK_WARNING_MSG = """
|
||||
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
|
||||
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
|
||||
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
|
||||
num_heads)`.
|
||||
"""
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if past_key_values:
|
||||
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
past_key_values = None
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
use_cache = False
|
||||
|
||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||
if head_mask is not None and decoder_head_mask is None:
|
||||
if self.config.num_layers == self.config.num_decoder_layers:
|
||||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
|
||||
decoder_head_mask = head_mask
|
||||
|
||||
in_decoder = stage_manager.stage >= decoder_starting_stage
|
||||
|
||||
# Stage is in encoder, directly return the output of t5_stack_forward
|
||||
if not in_decoder:
|
||||
encoder_outputs = T5PipelineForwards.t5_stack_forward(
|
||||
self.encoder,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
position_bias=position_bias,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
# last stage of encoder
|
||||
return {'encoder_outputs': encoder_outputs}
|
||||
else:
|
||||
return encoder_outputs
|
||||
|
||||
at_last_decoder_stage = stage_manager.is_last_stage()
|
||||
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
||||
|
||||
if encoder_outputs is None:
|
||||
raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
||||
encoder_outputs = BaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||
)
|
||||
|
||||
# Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
|
||||
if not at_first_decoder_stage and hidden_states is None:
|
||||
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
||||
|
||||
# Decode
|
||||
decoder_outputs = T5PipelineForwards.t5_stack_forward(
|
||||
self.decoder,
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
hidden_states=hidden_states,
|
||||
position_bias=position_bias,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
|
||||
# Directly return outputs of overloaded T5Stack forward if not at last stage.
|
||||
if not at_last_decoder_stage:
|
||||
decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
|
||||
return decoder_outputs
|
||||
|
||||
sequence_output = decoder_outputs[0]
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
# Rescale output before projecting on vocab
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
||||
sequence_output = sequence_output * (self.model_dim**-0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
# move labels to correct device to enable PP
|
||||
labels = labels.to(lm_logits.device)
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def t5_encoder_model_forward(
|
||||
self: T5EncoderModel,
|
||||
|
|
|
@ -293,21 +293,42 @@ class T5BasePolicy(Policy):
|
|||
|
||||
class T5ModelPolicy(T5BasePolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5Model
|
||||
base_policy = super().module_policy()
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
policy=base_policy,
|
||||
policy=policy,
|
||||
target_key=T5Model)
|
||||
return base_policy
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
return super().get_held_layers()
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
module = self.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager is not None and stage_manager.num_stages > 1:
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
|
||||
len(module.decoder.block),
|
||||
stage_manager.num_stages)
|
||||
|
||||
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
|
||||
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
|
||||
for k, v in binding_map.items():
|
||||
src = getattr_(self.model, k)
|
||||
|
@ -318,6 +339,9 @@ class T5ModelPolicy(T5BasePolicy):
|
|||
|
||||
class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
|
@ -335,8 +359,38 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
|||
],
|
||||
policy=policy,
|
||||
target_key=T5ForConditionalGeneration)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
|
||||
new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
|
||||
policy=policy)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
module = self.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager is not None and stage_manager.num_stages > 1:
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
|
||||
len(module.decoder.block),
|
||||
stage_manager.num_stages)
|
||||
|
||||
shared_params = []
|
||||
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
|
||||
shared_params.append({
|
||||
0: module.shared.weight,
|
||||
decoder_starting_stage: module.decoder.embed_tokens.weight
|
||||
})
|
||||
if id(module.lm_head.weight) == id(module.shared.weight):
|
||||
shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
|
||||
return shared_params
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
super().postprocess()
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
|
@ -382,7 +436,7 @@ class T5EncoderPolicy(T5BasePolicy):
|
|||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
|
||||
for k, v in binding_map.items():
|
||||
src = getattr_(self.model, k)
|
||||
|
|
|
@ -28,8 +28,6 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
|||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
if name != 'transformers_t5_encoder_model':
|
||||
continue
|
||||
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
|
@ -52,6 +50,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
|||
stage = stage_manager.stage
|
||||
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
|
||||
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
|
||||
in_decoder = stage >= decoder_starting_stage
|
||||
|
||||
if not at_first_stage:
|
||||
# change inputs if not the first stage
|
||||
|
@ -62,19 +61,25 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
|||
inputs['hidden_states'] = hidden_states
|
||||
inputs['position_bias'] = position_bias
|
||||
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
|
||||
if in_decoder:
|
||||
encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
inputs['encoder_outputs'] = (encoder_output_states,)
|
||||
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if at_last_stage:
|
||||
if name != 'transformers_t5_for_conditional_generation':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
if name == 'transformers_t5_for_conditional_generation' and in_decoder:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
if name != 'transformers_t5_encoder_model' and not in_decoder:
|
||||
output = output['encoder_outputs']
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
# position_bias information should be passed in T5
|
||||
assert 'position_bias' in output
|
||||
assert 'encoder_decoder_position_bias' in output
|
||||
assert output['position_bias'].shape == position_bias_shape
|
||||
if in_decoder:
|
||||
assert output['encoder_decoder_position_bias'].shape == position_bias_shape
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
Loading…
Reference in New Issue