# ColossalAI/colossalai/shardformer/policies/t5.py

import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from torch import Tensor, nn
from colossalai.shardformer.layer import (
DropoutForParallelInput,
Embedding1D,
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
VocabParallelEmbedding1D,
)
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.t5 import (
T5PipelineForwards,
get_jit_fused_T5_layer_ff_forward,
get_t5_flash_attention_forward,
get_T5_layer_cross_attention_forward,
get_T5_layer_self_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["T5BasePolicy", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
class T5BasePolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
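                # Illustrative arithmetic (assumed numbers, not taken from a real config):
                # vocab_size=32100 with tensor_parallel_size=8 gives 32100 % 8 == 4, so the
                # vocabulary is padded to 32100 + 8 - 4 = 32104, which splits evenly across 8 ranks.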
return self.model
def module_policy(self):
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
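            # Attention projections: q/k/v are split column-wise across ranks and the output
            # projection o is split row-wise; the per-rank attribute values below are adjusted
            # to match. Illustrative t5-base-like numbers with tensor_parallel_size=2:
            # d_model 768 -> 384, n_heads 12 -> 6, inner_dim (num_heads * d_kv) 768 -> 384.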
policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
"d_model":
self.model.config.d_model // self.shard_config.tensor_parallel_size,
"n_heads":
self.model.config.num_heads // self.shard_config.tensor_parallel_size,
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
ignore_if_not_exist=True)
])
policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
])
            policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="wi_0",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="wi_1",
                    target_module=Linear1D_Col,
                ),
                # wo consumes the partitioned intermediate activations, so it is row-parallel
                SubModuleReplacementDescription(
                    suffix="wo",
                    target_module=Linear1D_Row,
                ),
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=DropoutForParallelInput,
                )
            ])
policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
])
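            # Both FFN variants follow the usual Megatron-style pattern: the up-projection(s)
            # are column-parallel and the down-projection is row-parallel, so each rank computes
            # its own slice of the intermediate activations locally and only one all-reduce is
            # needed per FFN block (inside Linear1D_Row).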
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerCrossAttention)
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
        # use flash attention
        if self.shard_config.enable_flash_attention:
            # append to any existing T5Attention description instead of overwriting it,
            # so the tensor-parallel replacements registered above are preserved
            self.append_or_create_method_replacement(
                description={'forward': get_t5_flash_attention_forward()},
                policy=policy,
                target_key=T5Attention)

        # use jit operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    'forward': get_jit_fused_T5_layer_ff_forward(),
                    'dropout_add': get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF)
            self.append_or_create_method_replacement(
                description={
                    'forward': get_T5_layer_self_attention_forward(),
                    'dropout_add': get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention)
            self.append_or_create_method_replacement(
                description={
                    'forward': get_T5_layer_cross_attention_forward(),
                    'dropout_add': get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerCrossAttention)
return policy
def postprocess(self):
return self.model
@staticmethod
def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
num_stages: int) -> Tuple[List[int], int]:
"""
Distribute t5 layers into stages when pipeline parallel is used.
Return the layer distribution as a list and the starting stage of decoder.
If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
"""
# number of encoder layers must be a positive integer
if num_encoder_layers <= 0:
raise ValueError("The number of encoder layers for T5 must be a positive integer.")
# number of layers should be large enough to fill in every stage
if num_encoder_layers + num_decoder_layers < num_stages:
raise ValueError("The total number of layers can't be smaller than number of stages.")
# in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
if num_decoder_layers == 0:
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
        # the number of stages distributed between encoder and decoder is optimized in this way:
        # num_encoder_stages = argmin(abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / num_decoder_stages))
        # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages
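    # A worked example (assuming Policy.distribute_layers splits layers evenly):
    # distribute_t5_layers(num_encoder_layers=12, num_decoder_layers=12, num_stages=4)
    # evaluates objective(i) = |12/i - 12/(4-i)| for i in {1, 2, 3} -> [8, 0, 8], picks
    # num_encoder_stages=2, and returns ([6, 6, 6, 6], 2): stages 0-1 hold the encoder
    # blocks and stages 2-3 hold the decoder blocks.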
@staticmethod
    def get_t5_stage_index(layers_per_stage: List[int], stage: int,
                           decoder_starting_stage: int) -> Tuple[int, int]:
        """
        Given the distribution of layers among stages, the current stage and the first decoder stage,
        return the starting and ending indices of the layers held by this stage within the encoder or decoder.
        """
if stage < decoder_starting_stage:
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
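    # Example (assuming Policy.get_stage_index returns (start, end) block indices):
    # with layers_per_stage=[6, 6, 6, 6] and decoder_starting_stage=2, stage 1 maps to
    # encoder blocks (6, 12) while stage 3 maps to decoder blocks (6, 12).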
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
stage_manager = self.pipeline_stage_manager
model = self.model
encoder = self.model.encoder
decoder = getattr(self.model, 'decoder', None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
held_layers = []
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
if stage_manager.is_first_stage():
held_layers.append(model.shared)
held_layers.append(encoder.embed_tokens)
held_layers.append(encoder.dropout)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.final_layer_norm)
held_layers.append(encoder.dropout)
held_layers.extend(encoder.block[start_idx:end_idx])
else:
# current stage is in t5's decoder
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.dropout)
if stage_manager.is_last_stage():
held_layers.append(decoder.final_layer_norm)
held_layers.append(decoder.dropout)
held_layers.extend(decoder.block[start_idx:end_idx])
return held_layers
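    # For the 12+12-layer, 4-stage example above, stage 0 would hold model.shared,
    # encoder.embed_tokens, encoder.dropout and encoder.block[0:6]; stage 1 would add
    # encoder.final_layer_norm and the trailing dropout after encoder.block[6:12];
    # stages 2 and 3 would hold the mirrored decoder modules.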
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
encoder = self.model.encoder
decoder = getattr(self.model, 'decoder', None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
class T5ModelPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5Model
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=T5Model)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
len(module.decoder.block),
stage_manager.num_stages)
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
return []
class T5ForConditionalGenerationPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5ForConditionalGeneration
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
],
policy=policy,
target_key=T5ForConditionalGeneration)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
len(module.decoder.block),
stage_manager.num_stages)
shared_params = []
shared_embedding = {}
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
shared_embedding[0] = module.shared.weight
shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight
if id(module.lm_head.weight) == id(module.shared.weight):
shared_embedding[0] = module.shared.weight
shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight
if len(shared_embedding) > 0:
shared_params.append(shared_embedding)
return shared_params
return []
class T5EncoderPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5EncoderModel
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=T5EncoderModel)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=T5EncoderModel,
new_forward=T5PipelineForwards.t5_encoder_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []
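

# Illustrative usage sketch. The ShardFormer driver API shown below is an assumption and may
# differ between ColossalAI versions; the policies in this file are normally resolved
# automatically from the model class, so passing one explicitly is optional.
#
#   from transformers import T5ForConditionalGeneration
#   from colossalai.shardformer import ShardConfig, ShardFormer
#
#   model = T5ForConditionalGeneration.from_pretrained("t5-small")
#   shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
#   shard_former = ShardFormer(shard_config=shard_config)
#   sharded_model, shared_params = shard_former.optimize(model, policy=T5ForConditionalGenerationPolicy())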