2023-06-19 05:53:17 +00:00
|
|
|
from typing import Dict, Union
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-06-30 01:32:37 +00:00
|
|
|
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
2023-06-13 06:44:40 +00:00
|
|
|
|
2023-07-05 07:13:00 +00:00
|
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
2023-06-13 06:44:40 +00:00
|
|
|
|
2023-06-30 02:56:29 +00:00
|
|
|
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
|
|
|
|
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
class LlamaPolicy(Policy):
    """Sharding policy for the base LLaMA model.

    Describes which attributes and sub-modules of ``LlamaDecoderLayer`` and
    ``LlamaModel`` are replaced, depending on the flags found on
    ``self.shard_config``.
    """

    def config_sanity_check(self):
        """Perform no configuration checks for this policy."""
        pass

    def preprocess(self):
        """Pad the token embedding for tensor parallelism and return the model.

        When tensor parallelism is enabled and the vocabulary size is not a
        multiple of ``tensor_parallel_size``, the token embeddings are resized
        up to the next multiple. Otherwise the model is returned untouched.
        """
        if not self.shard_config.enable_tensor_parallelism:
            return self.model

        vocab_size = self.model.config.vocab_size
        tp_size = self.shard_config.tensor_parallel_size
        remainder = vocab_size % tp_size
        if remainder != 0:
            # Round the vocabulary size up to the next multiple of tp_size.
            self.model.resize_token_embeddings(vocab_size + tp_size - remainder)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the mapping from module class to its replacement description.

        Returns:
            A dict keyed by the transformers module classes. Entries are added
            for tensor parallelism and/or fused normalization according to
            ``self.shard_config``; the dict is empty when both are disabled.
        """
        # Imported lazily so transformers is only required when the policy is used.
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            tp_size = self.shard_config.tensor_parallel_size
            model_config = self.model.config

            # (suffix, replacement class) pairs, kept in the order they are registered.
            decoder_linears = (
                ("self_attn.q_proj", Linear1D_Col),
                ("self_attn.k_proj", Linear1D_Col),
                ("self_attn.v_proj", Linear1D_Col),
                ("self_attn.o_proj", Linear1D_Row),
                ("mlp.gate_proj", Linear1D_Col),
                ("mlp.up_proj", Linear1D_Col),
                ("mlp.down_proj", Linear1D_Row),
            )

            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                # Per-rank attention sizes: the global values divided by the group size.
                attribute_replacement={
                    "self_attn.hidden_size": model_config.hidden_size // tp_size,
                    "self_attn.num_heads": model_config.num_attention_heads // tp_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix=suffix, target_module=target)
                    for suffix, target in decoder_linears
                ],
            )

            token_embedding = SubModuleReplacementDescription(
                suffix="embed_tokens",
                target_module=VocabParallelEmbedding1D,
            )
            self.append_or_create_submodule_replacement(
                description=token_embedding,
                policy=policy,
                target_key=LlamaModel,
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            layer_norms = [
                SubModuleReplacementDescription(suffix="input_layernorm", target_module=FusedRMSNorm),
                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=FusedRMSNorm),
            ]
            self.append_or_create_submodule_replacement(
                description=layer_norms,
                policy=policy,
                target_key=LlamaDecoderLayer,
            )

            final_norm = SubModuleReplacementDescription(suffix="norm", target_module=FusedRMSNorm)
            self.append_or_create_submodule_replacement(
                description=final_norm,
                policy=policy,
                target_key=LlamaModel,
            )

        return policy

    def postprocess(self):
        """Return the model unchanged; no post-sharding work is done here."""
        return self.model
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
|
2023-06-19 05:53:17 +00:00
|
|
|
class LlamaForCausalLMPolicy(LlamaPolicy):
    """Sharding policy for ``LlamaForCausalLM``.

    Extends the base LLaMA policy with a replacement for the ``lm_head``
    sub-module when tensor parallelism is enabled.
    """

    def module_policy(self):
        """Return the base policy plus the causal LM head replacement."""
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # Register an extra entry for the causal LM wrapper class.
            lm_head = SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=Linear1D_Col,
                kwargs={"gather_output": True},
            )
            policy[LlamaForCausalLM] = ModulePolicyDescription(sub_module_replacement=[lm_head])

        return policy
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
|
|
|
|
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
    """Sharding policy for ``LlamaForSequenceClassification``.

    Extends the base LLaMA policy with a replacement for the ``score``
    sub-module when tensor parallelism is enabled.
    """

    def module_policy(self):
        """Return the base policy plus the classification head replacement."""
        from transformers import LlamaForSequenceClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # Register an extra entry for the sequence classification wrapper class.
            score_head = SubModuleReplacementDescription(
                suffix="score",
                target_module=Linear1D_Col,
                kwargs={"gather_output": True},
            )
            policy[LlamaForSequenceClassification] = ModulePolicyDescription(
                sub_module_replacement=[score_head])

        return policy
|