mirror of https://github.com/hpcaitech/ColossalAI
418 lines
16 KiB
Python
418 lines
16 KiB
Python
import warnings
|
|
from functools import partial
|
|
from typing import Callable, Dict, List
|
|
|
|
from torch import Tensor, nn
|
|
from torch.nn import Module
|
|
|
|
import colossalai.shardformer.layer as col_nn
|
|
|
|
from ..modeling.falcon import (
|
|
FalconPipelineForwards,
|
|
build_falcon_alibi_tensor_fn,
|
|
get_lm_forward_with_dist_cross_entropy,
|
|
get_tp_falcon_decoder_layer_forward,
|
|
)
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
__all__ = ["FalconPolicy"]
|
|
|
|
|
|
class FalconPolicy(Policy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def config_sanity_check(self):
|
|
pass
|
|
|
|
def preprocess(self):
|
|
self.tie_weight = self.tie_weight_check()
|
|
return self.model
|
|
|
|
def module_policy(self):
|
|
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
|
|
|
|
if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
|
|
warnings.warn(
|
|
"Falcon doesn't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag."
|
|
)
|
|
self.shard_config.enable_tensor_parallelism = False
|
|
|
|
if self.shard_config.enable_sequence_parallelism:
|
|
self.shard_config.enable_sequence_parallelism = False
|
|
warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
|
|
|
policy = {}
|
|
|
|
embedding_cls = None
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
embedding_cls = col_nn.VocabParallelEmbedding1D
|
|
else:
|
|
if self.tie_weight:
|
|
embedding_cls = col_nn.PaddingEmbedding
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
assert (
|
|
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
|
assert (
|
|
self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0
|
|
), f"The number of key_value heads must be divisible by tensor parallel size."
|
|
attn_attribute_replacement = {
|
|
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
|
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
|
"self_attention.num_heads": self.model.config.num_attention_heads
|
|
// self.shard_config.tensor_parallel_size,
|
|
"self_attention.num_kv_heads": self.model.config.num_kv_heads // self.shard_config.tensor_parallel_size,
|
|
}
|
|
|
|
policy[FalconDecoderLayer] = ModulePolicyDescription(
|
|
attribute_replacement=attn_attribute_replacement,
|
|
method_replacement={"forward": get_tp_falcon_decoder_layer_forward()},
|
|
sub_module_replacement=[
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attention.query_key_value",
|
|
target_module=col_nn.Linear1D_Col,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attention.dense",
|
|
target_module=col_nn.Linear1D_Row,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="self_attention.attention_dropout",
|
|
target_module=col_nn.DropoutForParallelInput,
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="mlp.dense_h_to_4h",
|
|
target_module=col_nn.Linear1D_Col,
|
|
),
|
|
SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row),
|
|
],
|
|
)
|
|
|
|
policy[FalconModel] = ModulePolicyDescription(
|
|
attribute_replacement={
|
|
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
|
},
|
|
method_replacement={
|
|
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
|
|
},
|
|
)
|
|
|
|
if embedding_cls is not None:
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="word_embeddings",
|
|
target_module=embedding_cls,
|
|
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=FalconModel,
|
|
)
|
|
|
|
# optimization configuration
|
|
if self.shard_config.enable_fused_normalization:
|
|
# handle falcon model
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="ln_f",
|
|
target_module=col_nn.FusedLayerNorm,
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=FalconModel,
|
|
)
|
|
|
|
# handle falcon decoder layer
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="ln_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="ln_mlp", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="input_layernorm", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=FalconDecoderLayer,
|
|
)
|
|
|
|
if self.shard_config.enable_flash_attention:
|
|
warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")
|
|
|
|
return policy
|
|
|
|
def postprocess(self):
|
|
return self.model
|
|
|
|
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
|
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
|
to customized forward method, and add this changing to policy."""
|
|
if self.pipeline_stage_manager:
|
|
stage_manager = self.pipeline_stage_manager
|
|
if self.model.__class__.__name__ == "FalconModel":
|
|
module = self.model
|
|
else:
|
|
module = self.model.transformer
|
|
|
|
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
|
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
|
method_replacement = {
|
|
"forward": partial(
|
|
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
|
)
|
|
}
|
|
self.append_or_create_method_replacement(
|
|
description=method_replacement, policy=policy, target_key=model_cls
|
|
)
|
|
|
|
def get_held_layers(self) -> List[Module]:
|
|
"""Get pipeline layers for current stage."""
|
|
assert self.pipeline_stage_manager is not None
|
|
if self.model.__class__.__name__ == "FalconModel":
|
|
module = self.model
|
|
else:
|
|
module = self.model.transformer
|
|
stage_manager = self.pipeline_stage_manager
|
|
held_layers = []
|
|
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
|
if stage_manager.is_first_stage():
|
|
held_layers.append(module.word_embeddings)
|
|
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
|
held_layers.extend(module.h[start_idx:end_idx])
|
|
if stage_manager.is_last_stage():
|
|
held_layers.append(module.ln_f)
|
|
|
|
return held_layers
|
|
|
|
|
|
class FalconModelPolicy(FalconPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def module_policy(self):
|
|
policy = super().module_policy()
|
|
|
|
from transformers.models.falcon.modeling_falcon import FalconModel
|
|
|
|
if self.pipeline_stage_manager:
|
|
self.set_pipeline_forward(
|
|
model_cls=FalconModel, new_forward=FalconPipelineForwards.falcon_model_forward, policy=policy
|
|
)
|
|
return policy
|
|
|
|
def get_held_layers(self) -> List[Module]:
|
|
"""
|
|
get pipeline layers for current stage
|
|
"""
|
|
held_layers = super().get_held_layers()
|
|
return held_layers
|
|
|
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
|
"""no shared params in falcon model"""
|
|
return []
|
|
|
|
|
|
class FalconForCausalLMPolicy(FalconPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def module_policy(self):
|
|
from transformers.models.falcon.modeling_falcon import FalconForCausalLM
|
|
|
|
policy = super().module_policy()
|
|
|
|
# handle tensor parallelism
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
self.append_or_create_submodule_replacement(
|
|
description=SubModuleReplacementDescription(
|
|
suffix="lm_head",
|
|
target_module=col_nn.VocabParallelLMHead1D,
|
|
kwargs=dict(
|
|
gather_output=not self.shard_config.parallel_output,
|
|
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
|
),
|
|
),
|
|
policy=policy,
|
|
target_key=FalconForCausalLM,
|
|
)
|
|
if self.shard_config.parallel_output:
|
|
method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
|
self.append_or_create_method_replacement(
|
|
description=method_replacement, policy=policy, target_key=FalconForCausalLM
|
|
)
|
|
|
|
else:
|
|
self.append_or_create_submodule_replacement(
|
|
description=SubModuleReplacementDescription(
|
|
suffix="lm_head",
|
|
target_module=col_nn.PaddingLMHead,
|
|
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
|
),
|
|
policy=policy,
|
|
target_key=FalconForCausalLM,
|
|
)
|
|
|
|
if self.pipeline_stage_manager:
|
|
self.set_pipeline_forward(
|
|
model_cls=FalconForCausalLM,
|
|
new_forward=FalconPipelineForwards.falcon_for_causal_lm_forward,
|
|
policy=policy,
|
|
)
|
|
return policy
|
|
|
|
def get_held_layers(self) -> List[Module]:
|
|
"""Get pipeline layers for current stage."""
|
|
stage_manager = self.pipeline_stage_manager
|
|
held_layers = super().get_held_layers()
|
|
if stage_manager.is_last_stage():
|
|
held_layers.append(self.model.lm_head)
|
|
return held_layers
|
|
|
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
|
falcon_model = self.model
|
|
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
|
if id(falcon_model.transformer.word_embeddings.weight) == id(falcon_model.lm_head.weight):
|
|
# tie weights
|
|
return [
|
|
{
|
|
0: falcon_model.transformer.word_embeddings.weight,
|
|
self.pipeline_stage_manager.num_stages - 1: falcon_model.lm_head.weight,
|
|
}
|
|
]
|
|
return []
|
|
|
|
|
|
class FalconForSequenceClassificationPolicy(FalconPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def module_policy(self):
|
|
from transformers.models.falcon.modeling_falcon import FalconForSequenceClassification
|
|
|
|
policy = super().module_policy()
|
|
|
|
# handle tensor parallelism
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
self.append_or_create_submodule_replacement(
|
|
description=SubModuleReplacementDescription(
|
|
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
|
),
|
|
policy=policy,
|
|
target_key=FalconForSequenceClassification,
|
|
)
|
|
|
|
if self.pipeline_stage_manager:
|
|
self.set_pipeline_forward(
|
|
model_cls=FalconForSequenceClassification,
|
|
new_forward=FalconPipelineForwards.falcon_for_sequence_classification_forward,
|
|
policy=policy,
|
|
)
|
|
return policy
|
|
|
|
def get_held_layers(self) -> List[Module]:
|
|
"""Get pipeline layers for current stage."""
|
|
stage_manager = self.pipeline_stage_manager
|
|
held_layers = super().get_held_layers()
|
|
if stage_manager.is_last_stage():
|
|
held_layers.append(self.model.score)
|
|
return held_layers
|
|
|
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
|
"""No shared params in falcon for sequence classification model"""
|
|
return []
|
|
|
|
|
|
class FalconForTokenClassificationPolicy(FalconPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def module_policy(self):
|
|
from transformers.models.falcon.modeling_falcon import FalconForTokenClassification
|
|
|
|
policy = super().module_policy()
|
|
|
|
# handle tensor parallelism
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
self.append_or_create_submodule_replacement(
|
|
description=[
|
|
SubModuleReplacementDescription(
|
|
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
|
),
|
|
SubModuleReplacementDescription(
|
|
suffix="dropout",
|
|
target_module=col_nn.DropoutForReplicatedInput,
|
|
),
|
|
],
|
|
policy=policy,
|
|
target_key=FalconForTokenClassification,
|
|
)
|
|
if self.pipeline_stage_manager:
|
|
self.set_pipeline_forward(
|
|
model_cls=FalconForTokenClassification,
|
|
new_forward=FalconPipelineForwards.falcon_for_token_classification_forward,
|
|
policy=policy,
|
|
)
|
|
return policy
|
|
|
|
def get_held_layers(self) -> List[Module]:
|
|
"""Get pipeline layers for current stage."""
|
|
stage_manager = self.pipeline_stage_manager
|
|
held_layers = super().get_held_layers()
|
|
if stage_manager.is_last_stage():
|
|
held_layers.append(self.model.dropout)
|
|
held_layers.append(self.model.classifier)
|
|
return held_layers
|
|
|
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
|
"""No shared params in falcon for token classification model"""
|
|
return []
|
|
|
|
|
|
class FalconForQuestionAnsweringPolicy(FalconPolicy):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def module_policy(self):
|
|
from transformers.models.falcon.modeling_falcon import FalconForQuestionAnswering
|
|
|
|
policy = super().module_policy()
|
|
|
|
# handle tensor parallelism
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
self.append_or_create_submodule_replacement(
|
|
description=SubModuleReplacementDescription(
|
|
suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
|
),
|
|
policy=policy,
|
|
target_key=FalconForQuestionAnswering,
|
|
)
|
|
if self.pipeline_stage_manager:
|
|
self.set_pipeline_forward(
|
|
model_cls=FalconForQuestionAnswering,
|
|
new_forward=FalconPipelineForwards.falcon_for_question_answering_forward,
|
|
policy=policy,
|
|
)
|
|
return policy
|
|
|
|
def get_held_layers(self) -> List[Module]:
|
|
"""Get pipeline layers for current stage."""
|
|
held_layers = super().get_held_layers()
|
|
stage_manager = self.pipeline_stage_manager
|
|
if stage_manager.is_last_stage():
|
|
held_layers.append(self.model.qa_outputs)
|
|
return held_layers
|
|
|
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
|
"""No shared params in falcon for question answering model"""
|
|
return []
|