mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] adapted llama to the new API (#4036)
parent
74d176c8d8
commit
c1d5453e9f
@ -1,122 +1,121 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class LlamaPolicy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
|
||||
def preprocess(self):
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
return {
|
||||
LlamaDecoderLayer:
|
||||
Argument(attr_dict={
|
||||
"self_attn.hidden_size": config.hidden_size // world_size,
|
||||
"self_attn.num_heads": config.num_attention_heads // world_size,
|
||||
},
|
||||
param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attn.hidden_size":
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
)
|
||||
],
|
||||
),
|
||||
LlamaModel:
|
||||
Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def attn_layer() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
suffix="self_attn.q_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="self_attn.k_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="self_attn.v_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Row_Layer(
|
||||
suffix="self_attn.o_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_layer() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
suffix="mlp.gate_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
gather_output=True,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="mlp.up_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
gather_output=True,
|
||||
),
|
||||
Col_Layer(
|
||||
suffix="mlp.down_proj",
|
||||
weight="weight",
|
||||
bias="bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
gather_output=True,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def embeddings() -> List:
|
||||
return [Col_Layer(
|
||||
suffix="embed_tokens",
|
||||
weight="weight",
|
||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||
)]
|
||||
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
|
||||
class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def argument(config, world_size):
|
||||
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
|
||||
argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
|
||||
argument.update(llamapolicy)
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
@staticmethod
|
||||
def lm_head() -> List:
|
||||
return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
|
||||
|
||||
class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||
|
||||
from transformers import LlamaForSequenceClassification
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
LlamaForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
|
||||
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||
|
||||
@staticmethod
|
||||
def argument(config, world_size):
|
||||
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
LlamaForSequenceClassification:
|
||||
Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
argument.update(llamapolicy)
|
||||
|
||||
@staticmethod
|
||||
def score() -> List:
|
||||
return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
Loading…
Reference in new issue