[pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)

* modify t5 policy & add test

* pipeline stage distribution for t5

* complete t5 base policy

* t5 stack: halfway

* modify gpt2 pipeline test

* complete pipeline forward for T5Stack/T5EncoderModel

* fix docstring

* move t5 util tests to test_pipeline
pull/4445/head
Baizhou Zhang 2023-07-21 16:23:04 +08:00 committed by Hongxin Liu
parent 18ebcf406a
commit 36e546b2cc
6 changed files with 604 additions and 21 deletions
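
As background for the diffs below: the new policies swap in stage-aware forward functions by binding pipeline metadata (stage_manager, stage_index, decoder_starting_stage) into the replacement forward with functools.partial. The following is a minimal, hypothetical sketch of that substitution pattern; the real wiring goes through append_or_create_method_replacement in shardformer and is not reproduced here.

from functools import partial
from types import MethodType

class TinyStage:
    def forward(self, x):
        return x

def stage_aware_forward(self, x, stage_index=None):
    # hypothetical replacement forward: pretend to run only the layers this stage owns
    start, end = stage_index
    return f"ran layers {start}..{end} on {x}"

model = TinyStage()
# bind the pipeline metadata, then attach the partial as a bound method
model.forward = MethodType(partial(stage_aware_forward, stage_index=[0, 2]), model)
print(model.forward("batch"))    # ran layers 0..2 on batch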

View File

@@ -0,0 +1,279 @@
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.checkpoint import checkpoint
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
class T5PipelineForwards:
'''
This class serves as a micro library for forward function substitution of
    T5 models under a pipeline parallel setting.
'''
@staticmethod
def t5_stack_forward(
self: T5Stack,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # This function is adapted from transformers.models.t5.modeling_t5.T5Stack.forward.
        # Please refer to the original transformers code for more details.
logger = logging.get_logger(__name__)
        # TODO: the recording of the kv-value tensors is left as () or None for now; this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
        stage = stage_manager.stage
        in_decoder = self.is_decoder
        if use_cache is True:
            if not in_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False
        if in_decoder != (stage >= decoder_starting_stage):
            raise ValueError("Config in T5Stack is not aligned with pipeline setting.")
        # at_first_stage: current stage is the first stage of the encoder/decoder, taking input_ids/inputs_embeds
        # at_last_stage: current stage is the last stage of the encoder/decoder, producing outputs in the same form as HuggingFace
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
# Process inputs if at the first stage of encoder/decoder.
if at_first_stage:
if input_ids is not None and inputs_embeds is not None:
err_msg_prefix = "decoder_" if in_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
err_msg_prefix = "decoder_" if in_decoder else ""
raise ValueError(
f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None:
if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings")
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
device = inputs_embeds.device
hidden_states = self.dropout(inputs_embeds)
else:
if hidden_states is None:
raise ValueError(
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
past_key_value = past_key_values[i]
layer_module = self.block[i]
layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i]
torch.cuda.set_device(hidden_states.device)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
hidden_states,
extended_attention_mask,
position_bias,
encoder_hidden_states,
encoder_extended_attention_mask,
encoder_decoder_position_bias,
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
if use_cache is False or use_cache is None:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2]
            # We share the position biases between the layers - the first layer stores them.
            # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
        # At the last stage of the encoder/decoder, apply the final layer norm and build the outputs.
if at_last_stage:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
if not return_dict:
return tuple(v for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
else:
return {
'hidden_states': hidden_states,
'position_bias': position_bias,
'encoder_decoder_position_bias': encoder_decoder_position_bias
}
@staticmethod
def t5_encoder_model_forward(
self: T5EncoderModel,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
This function is modified on the basis of transformers.models.t5.modeling_gpt2.T5EncoderModel.forward.
Please refer to original code of transformers for more details.
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = T5PipelineForwards.t5_stack_forward(self.encoder,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
return outputs
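
To make the stage-to-stage contract of t5_stack_forward above concrete, here is a small illustrative sketch (hypothetical shapes, not part of the commit): a stage that is not the last stage of its encoder/decoder returns a dict carrying the running hidden states plus the shared position biases, and the next stage consumes those tensors as keyword arguments while input_ids is left as None.

import torch

# Hypothetical sizes, for illustration only.
batch_size, seq_len, d_model, num_heads = 2, 16, 512, 8

# Payload emitted by a non-final pipeline stage of t5_stack_forward:
intermediate_output = {
    'hidden_states': torch.zeros(batch_size, seq_len, d_model),
    'position_bias': torch.zeros(batch_size, num_heads, seq_len, seq_len),
    'encoder_decoder_position_bias': torch.zeros(batch_size, num_heads, seq_len, seq_len),
}

# The next stage is then called with input_ids=None and these tensors passed through,
# which is exactly how the pipeline test at the bottom of this commit feeds intermediate stages.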

View File

@@ -1,3 +1,8 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
from torch import Tensor, nn
from colossalai.shardformer.layer import (
DropoutForParallelInput,
Embedding1D,
@@ -8,9 +13,11 @@ from colossalai.shardformer.layer import (
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_
from ..modeling.t5 import T5PipelineForwards
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
class T5BasePolicy(Policy):
@@ -166,6 +173,123 @@ class T5BasePolicy(Policy):
def postprocess(self):
return self.model
@staticmethod
def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
num_stages: int) -> Tuple[List[int], int]:
"""
Distribute t5 layers into stages when pipeline parallel is used.
Return the layer distribution as a list and the starting stage of decoder.
If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
"""
# number of encoder layers must be a positive integer
if num_encoder_layers <= 0:
raise ValueError("The number of encoder layers for T5 must be a positive integer.")
# number of layers should be large enough to fill in every stage
if num_encoder_layers + num_decoder_layers < num_stages:
raise ValueError("The total number of layers can't be smaller than number of stages.")
# in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
if num_decoder_layers == 0:
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
        # the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
# s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
num_encoder_stages = 0
optimal_diff = 2**31 - 1
for i in range(1, num_stages):
attempt = objective(i)
if attempt < optimal_diff:
num_encoder_stages = i
optimal_diff = attempt
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages
@staticmethod
    def get_t5_stage_index(layers_per_stage: List[int], stage: int,
                           decoder_starting_stage: int) -> Tuple[int, int]:
        """
        Given the distribution of layers among stages, the current stage and the first stage of the decoder,
        return the starting/ending indices of the layers held by the current stage within its encoder/decoder.
        """
if stage < decoder_starting_stage:
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
stage_manager = self.pipeline_stage_manager
model = self.model
encoder = self.model.encoder
        decoder = getattr(self.model, 'decoder', None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
held_layers = []
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
if stage_manager.is_first_stage():
held_layers.append(model.shared)
held_layers.append(encoder.embed_tokens)
held_layers.append(encoder.dropout)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.final_layer_norm)
held_layers.append(encoder.dropout)
held_layers.extend(encoder.block[start_idx:end_idx])
else:
# current stage is in t5's decoder
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.dropout)
if stage_manager.is_last_stage():
held_layers.append(decoder.final_layer_norm)
held_layers.append(decoder.dropout)
held_layers.extend(decoder.block[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
encoder = self.model.encoder
        decoder = getattr(self.model, 'decoder', None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
class T5ModelPolicy(T5BasePolicy):
@@ -182,6 +306,15 @@ class T5ModelPolicy(T5BasePolicy):
target_key=T5Model)
return base_policy
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
for k, v in binding_map.items():
src = getattr_(self.model, k)
for dst in v:
setattr_(self.model, dst, src)
return self.model
class T5ForConditionalGenerationPolicy(T5BasePolicy):
@@ -204,19 +337,55 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
target_key=T5ForConditionalGeneration)
return policy
def postprocess(self):
super().postprocess()
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {
"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
}
for k, v in binding_map.items():
src = getattr_(self.model, k)
for dst in v:
setattr_(self.model, dst, src)
return self.model
class T5EncoderPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5EncoderModel
-        base_policy = super().module_policy()
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
-                                                        policy=base_policy,
policy=policy,
target_key=T5EncoderModel)
-        return base_policy
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5EncoderModel,
new_forward=T5PipelineForwards.t5_encoder_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
for k, v in binding_map.items():
src = getattr_(self.model, k)
for dst in v:
setattr_(self.model, dst, src)
return self.model
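
The encoder/decoder stage split computed by distribute_t5_layers above can be reproduced in a few lines. Below is a standalone sketch (hypothetical helper name, input validation omitted) that mirrors the argmin objective and matches the expectations of the distribution test further down.

def split_decoder_starting_stage(num_encoder_layers: int, num_decoder_layers: int, num_stages: int) -> int:
    """Return the stage index at which the decoder starts (num_stages if there is no decoder)."""
    if num_decoder_layers == 0:
        # encoder-only model (e.g. T5EncoderModel): the decoder "starts" past the last stage
        return num_stages

    def objective(num_encoder_stages: int) -> float:
        # balance layers-per-stage between encoder and decoder
        return abs(num_encoder_layers / num_encoder_stages -
                   num_decoder_layers / (num_stages - num_encoder_stages))

    # first minimizer over 1 <= num_encoder_stages <= num_stages - 1, as in the policy above
    return min(range(1, num_stages), key=objective)

# Two of the cases checked in test_t5_pipeline_distribution below:
assert split_decoder_starting_stage(10, 6, 8) == 5
assert split_decoder_starting_stage(5, 22, 8) == 2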

View File

@@ -62,10 +62,8 @@ output_transform_fn = lambda x: x
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss
-config = transformers.GPT2Config(
-    n_layer=2,
config = transformers.GPT2Config(n_layer=2,
n_head=4,
#n_embd=128,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,

View File

@@ -0,0 +1,39 @@
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_t5_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end

View File

@@ -29,9 +29,11 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
-        input_ids, _ = inputs['input_ids'], inputs['attention_mask']
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
input_ids = inputs['input_ids']
batch_size, seq_len = input_ids.shape
-        hidden_size = 768
hidden_size = sharded_model.config.n_embd
hidden_state_shape = (batch_size, seq_len, hidden_size)
if not stage_manager.is_first_stage():
@@ -40,12 +42,12 @@
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
-        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
-            if name != 'transformers_gpt':
if name == 'transformers_gpt':
assert output[0].shape == hidden_state_shape
else:
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape

View File

@@ -0,0 +1,96 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_t5.py
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
if name != 'transformers_t5_encoder_model':
continue
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids = inputs['input_ids']
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
batch_size, seq_len = input_ids.shape
hidden_size = sharded_model.config.d_model
num_heads = sharded_model.config.num_heads
hidden_state_shape = (batch_size, seq_len, hidden_size)
position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
num_encoder_layers = len(sharded_model.encoder.block)
        decoder = getattr(sharded_model, 'decoder', None)
num_decoder_layers = len(decoder.block) if decoder else 0
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
stage = stage_manager.stage
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
if not at_first_stage:
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
position_bias = torch.zeros(*position_bias_shape).cuda()
encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
inputs['position_bias'] = position_bias
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
sharded_model.train()
output = sharded_model(**inputs)
if at_last_stage:
if name != 'transformers_t5_for_conditional_generation':
assert output[0].shape == hidden_state_shape
else:
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
# position_bias information should be passed in T5
assert 'position_bias' in output
assert 'encoder_decoder_position_bias' in output
torch.cuda.empty_cache()
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5():
spawn(check_t5, 4)
if __name__ == "__main__":
test_t5()
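
For orientation, the stage roles exercised by the encoder-only branch of this test (PP_SIZE = 2, no decoder layers) can be spelled out directly; the encoder layer count of 4 below is a hypothetical value, only decoder_starting_stage matters here.

from colossalai.shardformer.policies.t5 import T5BasePolicy

# Encoder-only case: with num_decoder_layers == 0, the decoder "starts" after the last stage.
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers=4,
                                                              num_decoder_layers=0,
                                                              num_stages=2)
assert decoder_starting_stage == 2
# So in the test above: stage 0 is at_first_stage (consumes input_ids) and
# stage 1 is at_last_stage (consumes hidden_states/position_bias and returns the final output).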