mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
266 lines
11 KiB
266 lines
11 KiB
from functools import partial
|
|
from typing import Callable, Dict, List, Union
|
|
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.nn import Module
|
|
|
|
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
|
|
|
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
|
|
|
# Public policy classes exported by this module (LlamaModelPolicy was previously missing).
__all__ = ['LlamaPolicy', 'LlamaModelPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
|
|
|
|
|
|
class LlamaPolicy(Policy):
    """Base shard policy for HuggingFace Llama models.

    Describes how to apply tensor parallelism (1D column/row-parallel linears plus a
    vocab-parallel embedding), optional fused RMSNorm and flash attention, and how to
    split the decoder layers across pipeline stages.
    """

    def config_sanity_check(self):
        pass

    def preprocess(self):
        """Pad the vocabulary so it splits evenly across tensor-parallel ranks.

        When tensor parallelism is enabled and ``vocab_size`` is not a multiple of
        ``tensor_parallel_size``, the token embeddings are resized up to the next
        multiple. Returns the (possibly resized) model.
        """
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the module-replacement policy for the Llama backbone.

        Returns:
            A dict mapping HuggingFace module classes to the ``ModulePolicyDescription``
            to apply to them.

        Raises:
            ValueError: if tensor parallelism is enabled and ``num_attention_heads`` (or
                ``num_key_value_heads``, when the config defines it) is not divisible by
                ``tensor_parallel_size``.
        """
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            config = self.model.config
            tp_size = self.shard_config.tensor_parallel_size

            if config.num_attention_heads % tp_size != 0:
                raise ValueError(f"num_attention_heads ({config.num_attention_heads}) must be divisible by "
                                 f"tensor_parallel_size ({tp_size}).")

            # q/k/v are column-sharded below, so each rank only holds 1/tp_size of the heads;
            # the per-layer attention attributes must shrink by the same factor.
            decoder_attribute_replacement = {
                "self_attn.hidden_size": config.hidden_size // tp_size,
                "self_attn.num_heads": config.num_attention_heads // tp_size,
            }

            # Configs with grouped-query attention define num_key_value_heads, which sizes
            # k_proj/v_proj. Those projections are column-sharded too, so the per-rank kv-head
            # count must be divided as well, or the k/v reshape no longer matches the shard.
            # NOTE(review): relies on LlamaAttention reading self.num_key_value_heads in its
            # forward (transformers >= 4.31) -- confirm against the pinned version.
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            if num_kv_heads is not None:
                if num_kv_heads % tp_size != 0:
                    raise ValueError(f"num_key_value_heads ({num_kv_heads}) must be divisible by "
                                     f"tensor_parallel_size ({tp_size}).")
                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size

            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                suffix="embed_tokens",
                target_module=VocabParallelEmbedding1D,
            ),
                                                        policy=policy,
                                                        target_key=LlamaModel)

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=FusedRMSNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=FusedRMSNorm,
                ),
            ],
                                                        policy=policy,
                                                        target_key=LlamaDecoderLayer)

            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=FusedRMSNorm,
            ),
                                                        policy=policy,
                                                        target_key=LlamaModel)

        if self.shard_config.enable_flash_attention:
            policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
                'forward': get_llama_flash_attention_forward(),
            })

        return policy

    def postprocess(self):
        return self.model

    def _get_llama_model(self) -> Module:
        """Return the bare ``LlamaModel``.

        That is ``self.model`` itself, or its ``.model`` attribute when ``self.model`` is a
        task-head wrapper such as ``LlamaForCausalLM``.
        """
        if self.model.__class__.__name__ == "LlamaModel":
            return self.model
        return self.model.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """Under pipeline parallelism, replace ``model_cls.forward`` with a stage-aware forward.

        ``new_forward`` is bound (via ``functools.partial``) to this stage's manager and to
        the index range of decoder layers this stage runs. The replacement is recorded in
        ``policy`` under ``model_cls``. Does nothing when no pipeline stage manager is set.
        """
        if not self.pipeline_stage_manager:
            return

        stage_manager = self.pipeline_stage_manager
        module = self._get_llama_model()

        # Use self.* (exactly as get_held_layers does) so the layer range used by the forward
        # always matches the layers this stage actually holds.
        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
        self.append_or_create_method_replacement(description=method_replacement,
                                                 policy=policy,
                                                 target_key=model_cls)

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage.

        The first stage also holds ``embed_tokens``; the last stage also holds the final
        ``norm``; every stage holds its slice of ``layers``.
        """
        assert self.pipeline_stage_manager is not None

        module = self._get_llama_model()
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers
|
|
|
|
|
|
class LlamaModelPolicy(LlamaPolicy):
    """Shard policy for a bare ``LlamaModel`` (no task head on top)."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Return the base Llama policy, adding the pipeline forward for ``LlamaModel`` when
        a pipeline stage manager is present."""
        from transformers.models.llama.modeling_llama import LlamaModel

        policy = super().module_policy()
        if not self.pipeline_stage_manager:
            return policy

        self.set_pipeline_forward(model_cls=LlamaModel,
                                  new_forward=LlamaPipelineForwards.llama_model_forward,
                                  policy=policy)
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage (identical to the base policy)."""
        return super().get_held_layers()

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in llama model"""
        return []
|
|
|
|
|
|
class LlamaForCausalLMPolicy(LlamaPolicy):
    """Shard policy for ``transformers.LlamaForCausalLM``.

    Extends the base Llama policy with sharding of the ``lm_head`` projection, the
    causal-LM pipeline forward, and tied-embedding handling across pipeline stages.
    """

    def module_policy(self):
        """Return the base Llama policy plus ``lm_head`` sharding and the causal-LM pipeline forward."""
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm.
            # NOTE(review): gather_output=True presumably all-gathers the column-sharded logits
            # so callers see the full vocabulary -- confirm against Linear1D_Col.
            new_item = {
                LlamaForCausalLM:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
                    ])
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=LlamaForCausalLM,
                                      new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
                                      policy=policy)

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage; the last stage also holds ``lm_head``."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Return parameters shared between pipeline stages.

        When there is more than one pipeline stage and ``lm_head`` uses the very same weight
        tensor as ``embed_tokens`` (tied weights), that tensor lives on both the first stage
        (key ``0``) and the last stage (key ``num_stages - 1``). It is then returned as a
        single mapping. Otherwise an empty list is returned.
        """
        llama_model = self.model.model
        stage_manager = self.pipeline_stage_manager

        if not stage_manager or stage_manager.num_stages <= 1:
            return []

        # `is` tests object identity directly (the original compared id()s, which is equivalent).
        if llama_model.embed_tokens.weight is self.model.lm_head.weight:
            # tie weights
            return [{
                0: llama_model.embed_tokens.weight,
                stage_manager.num_stages - 1: self.model.lm_head.weight,
            }]
        return []
|
|
|
|
|
|
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
    """Shard policy for ``transformers.LlamaForSequenceClassification``.

    Extends the base Llama policy with sharding of the ``score`` head and the
    sequence-classification pipeline forward.
    """

    def module_policy(self):
        """Return the base Llama policy plus ``score`` sharding and the classification pipeline forward."""
        from transformers import LlamaForSequenceClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            score_replacement = SubModuleReplacementDescription(suffix="score",
                                                                target_module=Linear1D_Col,
                                                                kwargs=dict(gather_output=True))
            policy[LlamaForSequenceClassification] = ModulePolicyDescription(
                sub_module_replacement=[score_replacement])

        # NOTE(review): the pipeline path for this head was marked "to be confirmed" upstream.
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=LlamaForSequenceClassification,
                                      new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
                                      policy=policy)
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage; the last stage also holds ``score``."""
        manager = self.pipeline_stage_manager
        layers = super().get_held_layers()
        if manager.is_last_stage():
            layers.append(self.model.score)
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in llama for sequence classification model"""
        return []
|