mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)
* modify t5 policy & add test * pipeline stage distribution for t5 * complete t5 base policy * t5 stack: halfway * modify gpt2 pipeline test * complete pipeline forward for T5Stack/T5EncoderModel * fix docstring * move t5 util tests to test_pipelinepull/4445/head
parent
18ebcf406a
commit
36e546b2cc
|
@ -0,0 +1,279 @@
|
|||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
class T5PipelineForwards:
|
||||
'''
|
||||
This class serves as a micro library for forward function substitution of
|
||||
T5 models under pipeline setting.
|
||||
'''
|
||||
|
||||
@staticmethod
|
||||
def t5_stack_forward(
|
||||
self: T5Stack,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
decoder_starting_stage: Optional[int] = None,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
|
||||
# This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if past_key_values:
|
||||
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
past_key_values = None
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
use_cache = False
|
||||
if use_cache is True:
|
||||
if not in_decoder:
|
||||
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
|
||||
stage = stage_manager.stage
|
||||
in_decoder = self.is_decoder
|
||||
if in_decoder != (stage >= decoder_starting_stage):
|
||||
raise ValueError("Config in T5Stack is not aligned with pipeline setting.")
|
||||
|
||||
# at_first_stage: current stage is the first stage of encoder/decoder, taking input_ids/input_embedds
|
||||
# at_last_stage: current stage is the last stage of encoder/decoder, making outputs the same form as huggingface
|
||||
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
|
||||
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
|
||||
|
||||
# Process inputs if at the first stage of encoder/decoder.
|
||||
if at_first_stage:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
err_msg_prefix = "decoder_" if in_decoder else ""
|
||||
raise ValueError(
|
||||
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
err_msg_prefix = "decoder_" if in_decoder else ""
|
||||
raise ValueError(
|
||||
f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
|
||||
if inputs_embeds is None:
|
||||
if self.embed_tokens is None:
|
||||
raise ValueError("You have to initialize the model with valid token embeddings")
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
batch_size, seq_length = input_shape
|
||||
device = inputs_embeds.device
|
||||
hidden_states = self.dropout(inputs_embeds)
|
||||
else:
|
||||
if hidden_states is None:
|
||||
raise ValueError(
|
||||
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape[0], input_shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
|
||||
if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
|
||||
|
||||
# initialize past_key_values with `None` if past does not exist
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(self.block)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
|
||||
present_key_value_states = () if use_cache else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
|
||||
|
||||
# Going through held blocks.
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
for i in range(start_idx, end_idx):
|
||||
|
||||
past_key_value = past_key_values[i]
|
||||
layer_module = self.block[i]
|
||||
layer_head_mask = head_mask[i]
|
||||
cross_attn_layer_head_mask = cross_attn_head_mask[i]
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return tuple(module(*inputs, use_cache, output_attentions))
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
extended_attention_mask,
|
||||
position_bias,
|
||||
encoder_hidden_states,
|
||||
encoder_extended_attention_mask,
|
||||
encoder_decoder_position_bias,
|
||||
layer_head_mask,
|
||||
cross_attn_layer_head_mask,
|
||||
None, # past_key_value is always None with gradient checkpointing
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask=extended_attention_mask,
|
||||
position_bias=position_bias,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
layer_head_mask=layer_head_mask,
|
||||
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# layer_outputs is a tuple with:
|
||||
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
|
||||
|
||||
if use_cache is False or use_cache is None:
|
||||
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
|
||||
hidden_states, present_key_value_state = layer_outputs[:2]
|
||||
# print(stage, len(layer_outputs), present_key_value_state.shape)
|
||||
|
||||
# We share the position biases between the layers - the first layer store them
|
||||
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
|
||||
# (cross-attention position bias), (cross-attention weights)
|
||||
position_bias = layer_outputs[2]
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
|
||||
# append next layer key value states
|
||||
if use_cache:
|
||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||
|
||||
# last layer
|
||||
if at_last_stage:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [
|
||||
hidden_states,
|
||||
present_key_value_states,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
all_cross_attentions,
|
||||
] if v is not None)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=present_key_value_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
else:
|
||||
return {
|
||||
'hidden_states': hidden_states,
|
||||
'position_bias': position_bias,
|
||||
'encoder_decoder_position_bias': encoder_decoder_position_bias
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def t5_encoder_model_forward(
|
||||
self: T5EncoderModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
decoder_starting_stage: Optional[int] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
||||
r"""
|
||||
This function is modified on the basis of transformers.models.t5.modeling_gpt2.T5EncoderModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = T5PipelineForwards.t5_stack_forward(self.encoder,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
position_bias=position_bias,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
|
||||
return outputs
|
|
@ -1,3 +1,8 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
DropoutForParallelInput,
|
||||
Embedding1D,
|
||||
|
@ -8,9 +13,11 @@ from colossalai.shardformer.layer import (
|
|||
)
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.t5 import T5PipelineForwards
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
||||
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
||||
|
||||
|
||||
class T5BasePolicy(Policy):
|
||||
|
@ -166,6 +173,123 @@ class T5BasePolicy(Policy):
|
|||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
@staticmethod
|
||||
def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
|
||||
num_stages: int) -> Tuple[List[int], int]:
|
||||
"""
|
||||
Distribute t5 layers into stages when pipeline parallel is used.
|
||||
Return the layer distribution as a list and the starting stage of decoder.
|
||||
If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
|
||||
"""
|
||||
|
||||
# number of encoder layers must be a positive integer
|
||||
if num_encoder_layers <= 0:
|
||||
raise ValueError("The number of encoder layers for T5 must be a positive integer.")
|
||||
|
||||
# number of layers should be large enough to fill in every stage
|
||||
if num_encoder_layers + num_decoder_layers < num_stages:
|
||||
raise ValueError("The total number of layers can't be smaller than number of stages.")
|
||||
|
||||
# in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
|
||||
if num_decoder_layers == 0:
|
||||
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
|
||||
|
||||
# the number of stages distributed between encoder and decoder is optmized in this way:
|
||||
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
|
||||
# s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
|
||||
def objective(num_encoder_stages):
|
||||
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
|
||||
|
||||
num_encoder_stages = 0
|
||||
optimal_diff = 2**31 - 1
|
||||
for i in range(1, num_stages):
|
||||
attempt = objective(i)
|
||||
if attempt < optimal_diff:
|
||||
num_encoder_stages = i
|
||||
optimal_diff = attempt
|
||||
num_decoder_stages = num_stages - num_encoder_stages
|
||||
|
||||
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
|
||||
return encoder_distribution + decoder_distribution, num_encoder_stages
|
||||
|
||||
@staticmethod
|
||||
def get_t5_stage_index(layers_per_stage: List[int], stage: int,
|
||||
decoder_starting_stage: int) -> Tuple[bool, int, int]:
|
||||
"""
|
||||
Input the distribution of layers among stages, the current stage and the first stage of decoder.
|
||||
Return the starting/ending idx of layers in encoder/decoder
|
||||
"""
|
||||
if stage < decoder_starting_stage:
|
||||
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
|
||||
else:
|
||||
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
model = self.model
|
||||
encoder = self.model.encoder
|
||||
decoder = self.model.__dict__.get('decoder', None)
|
||||
|
||||
num_encoder_layers = len(encoder.block)
|
||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
|
||||
start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage,
|
||||
decoder_starting_stage)
|
||||
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in t5's encoder
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(model.shared)
|
||||
held_layers.append(encoder.embed_tokens)
|
||||
held_layers.append(encoder.dropout)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
held_layers.append(encoder.final_layer_norm)
|
||||
held_layers.append(encoder.dropout)
|
||||
held_layers.extend(encoder.block[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in t5's decoder
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.dropout)
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(decoder.final_layer_norm)
|
||||
held_layers.append(decoder.dropout)
|
||||
held_layers.extend(decoder.block[start_idx:end_idx])
|
||||
return held_layers
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if not self.pipeline_stage_manager:
|
||||
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
encoder = self.model.encoder
|
||||
decoder = self.model.__dict__.get('decoder', None)
|
||||
|
||||
num_encoder_layers = len(encoder.block)
|
||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||
|
||||
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
|
||||
stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
|
||||
|
||||
method_replacement = {
|
||||
'forward':
|
||||
partial(new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
|
||||
class T5ModelPolicy(T5BasePolicy):
|
||||
|
||||
|
@ -182,6 +306,15 @@ class T5ModelPolicy(T5BasePolicy):
|
|||
target_key=T5Model)
|
||||
return base_policy
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
|
||||
for k, v in binding_map.items():
|
||||
src = getattr_(self.model, k)
|
||||
for dst in v:
|
||||
setattr_(self.model, dst, src)
|
||||
return self.model
|
||||
|
||||
|
||||
class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||
|
||||
|
@ -204,19 +337,55 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
|||
target_key=T5ForConditionalGeneration)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
super().postprocess()
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {
|
||||
"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
}
|
||||
for k, v in binding_map.items():
|
||||
src = getattr_(self.model, k)
|
||||
for dst in v:
|
||||
setattr_(self.model, dst, src)
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
class T5EncoderPolicy(T5BasePolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
base_policy = super().module_policy()
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
policy=base_policy,
|
||||
policy=policy,
|
||||
target_key=T5EncoderModel)
|
||||
return base_policy
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=T5EncoderModel,
|
||||
new_forward=T5PipelineForwards.t5_encoder_model_forward,
|
||||
policy=policy)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
return super().get_held_layers()
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
|
||||
for k, v in binding_map.items():
|
||||
src = getattr_(self.model, k)
|
||||
for dst in v:
|
||||
setattr_(self.model, dst, src)
|
||||
return self.model
|
||||
|
|
|
@ -62,10 +62,8 @@ output_transform_fn = lambda x: x
|
|||
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = transformers.GPT2Config(
|
||||
n_layer=2,
|
||||
config = transformers.GPT2Config(n_layer=2,
|
||||
n_head=4,
|
||||
#n_embd=128,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||
|
||||
|
||||
def test_t5_pipeline_distribution():
|
||||
num_test_cases = 8
|
||||
test_dict = {
|
||||
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
|
||||
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
|
||||
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
|
||||
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
|
||||
test_dict['num_decoder_layers'][i],
|
||||
test_dict['num_stages'][i])
|
||||
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
|
||||
|
||||
|
||||
def test_t5_pipeline_layers():
|
||||
num_test_cases = 4
|
||||
test_dict = {
|
||||
'num_encoder_layers': [2, 3, 2, 4],
|
||||
'num_decoder_layers': [2, 0, 2, 8],
|
||||
'num_stages': [2, 2, 4, 4],
|
||||
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
|
||||
[[0, 4], [0, 3], [3, 6], [6, 8]]]
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
|
||||
|
||||
for stage in range(test_dict['num_stages'][i]):
|
||||
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
|
||||
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
|
||||
decoder_starting_stage)
|
||||
assert start_idx == predicted_start
|
||||
assert end_idx == predicted_end
|
|
@ -29,9 +29,11 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
|||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
input_ids = inputs['input_ids']
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = 768
|
||||
hidden_size = sharded_model.config.n_embd
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
|
||||
if not stage_manager.is_first_stage():
|
||||
|
@ -40,12 +42,12 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
|||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
if name != 'transformers_gpt':
|
||||
if name == 'transformers_gpt':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# TODO: add tests for forward/backward later
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_t5.py
|
||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
if name != 'transformers_t5_encoder_model':
|
||||
continue
|
||||
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
input_ids = inputs['input_ids']
|
||||
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = sharded_model.config.d_model
|
||||
num_heads = sharded_model.config.num_heads
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
|
||||
|
||||
num_encoder_layers = len(sharded_model.encoder.block)
|
||||
decoder = sharded_model.__dict__.get('decoder', None)
|
||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
|
||||
stage = stage_manager.stage
|
||||
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
|
||||
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
|
||||
|
||||
if not at_first_stage:
|
||||
# change inputs if not the first stage
|
||||
hidden_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
position_bias = torch.zeros(*position_bias_shape).cuda()
|
||||
encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
inputs['position_bias'] = position_bias
|
||||
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
|
||||
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if at_last_stage:
|
||||
if name != 'transformers_t5_for_conditional_generation':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
# position_bias information should be passed in T5
|
||||
assert 'position_bias' in output
|
||||
assert 'encoder_decoder_position_bias' in output
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_t5(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_t5_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_t5():
|
||||
spawn(check_t5, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_t5()
|
Loading…
Reference in New Issue