mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] adapted llama to the new API (#4036)
parent
74d176c8d8
commit
c1d5453e9f
@ -1,122 +1,121 @@
|
|||||||
from dataclasses import dataclass
|
from typing import Dict, Union
|
||||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
|
||||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||||
|
|
||||||
import colossalai.shardformer.layer.layers as col_nn
|
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||||
|
|
||||||
from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
|
||||||
class LlamaPolicy(Policy):
|
class LlamaPolicy(Policy):
|
||||||
|
|
||||||
@staticmethod
|
def preprocess(self):
|
||||||
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
|
# Resize embedding
|
||||||
|
vocab_size = self.model.config.vocab_size
|
||||||
|
world_size = self.shard_config.tensor_parallel_size
|
||||||
|
|
||||||
|
if vocab_size % world_size != 0:
|
||||||
|
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||||
|
self.model.resize_token_embeddings(new_vocab_size)
|
||||||
|
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||||
return {
|
return {
|
||||||
LlamaDecoderLayer:
|
LlamaDecoderLayer:
|
||||||
Argument(attr_dict={
|
ModulePolicyDescription(
|
||||||
"self_attn.hidden_size": config.hidden_size // world_size,
|
attribute_replacement={
|
||||||
"self_attn.num_heads": config.num_attention_heads // world_size,
|
"self_attn.hidden_size":
|
||||||
|
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
|
"self_attn.num_heads":
|
||||||
|
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||||
},
|
},
|
||||||
param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
|
param_replacement=[],
|
||||||
LlamaModel:
|
sub_module_replacement=[
|
||||||
Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
|
SubModuleReplacementDescription(
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def attn_layer() -> List:
|
|
||||||
return [
|
|
||||||
Col_Layer(
|
|
||||||
suffix="self_attn.q_proj",
|
suffix="self_attn.q_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Col,
|
||||||
bias="bias",
|
|
||||||
replace_layer=col_nn.Linear1D_Col,
|
|
||||||
),
|
),
|
||||||
Col_Layer(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.k_proj",
|
suffix="self_attn.k_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Col,
|
||||||
bias="bias",
|
|
||||||
replace_layer=col_nn.Linear1D_Col,
|
|
||||||
),
|
),
|
||||||
Col_Layer(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.v_proj",
|
suffix="self_attn.v_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Col,
|
||||||
bias="bias",
|
|
||||||
replace_layer=col_nn.Linear1D_Col,
|
|
||||||
),
|
),
|
||||||
Row_Layer(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.o_proj",
|
suffix="self_attn.o_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Row,
|
||||||
bias="bias",
|
),
|
||||||
replace_layer=col_nn.Linear1D_Row,
|
SubModuleReplacementDescription(
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mlp_layer() -> List:
|
|
||||||
return [
|
|
||||||
Col_Layer(
|
|
||||||
suffix="mlp.gate_proj",
|
suffix="mlp.gate_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Col,
|
||||||
bias="bias",
|
|
||||||
replace_layer=col_nn.Linear1D_Col,
|
|
||||||
gather_output=True,
|
|
||||||
),
|
),
|
||||||
Col_Layer(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.up_proj",
|
suffix="mlp.up_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Col,
|
||||||
bias="bias",
|
|
||||||
replace_layer=col_nn.Linear1D_Row,
|
|
||||||
gather_output=True,
|
|
||||||
),
|
),
|
||||||
Col_Layer(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.down_proj",
|
suffix="mlp.down_proj",
|
||||||
weight="weight",
|
target_module=Linear1D_Row,
|
||||||
bias="bias",
|
)
|
||||||
replace_layer=col_nn.Linear1D_Col,
|
],
|
||||||
gather_output=True,
|
|
||||||
),
|
),
|
||||||
]
|
LlamaModel:
|
||||||
|
ModulePolicyDescription(attribute_replacement={},
|
||||||
@staticmethod
|
param_replacement=[],
|
||||||
def embeddings() -> List:
|
sub_module_replacement=[
|
||||||
return [Col_Layer(
|
SubModuleReplacementDescription(
|
||||||
suffix="embed_tokens",
|
suffix="embed_tokens",
|
||||||
weight="weight",
|
target_module=VocabParallelEmbedding1D,
|
||||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
)
|
||||||
)]
|
])
|
||||||
|
}
|
||||||
from transformers import LlamaForCausalLM
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaForCausalLMPolicy(LlamaPolicy):
|
def new_model_class(self):
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
def postprocess(self):
|
||||||
def argument(config, world_size):
|
return self.model
|
||||||
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
|
|
||||||
argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
|
|
||||||
argument.update(llamapolicy)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def lm_head() -> List:
|
|
||||||
return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
|
|
||||||
|
|
||||||
|
class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
|
|
||||||
from transformers import LlamaForSequenceClassification
|
def module_policy(self):
|
||||||
|
policy = super().module_policy()
|
||||||
|
# add a new item for casual lm
|
||||||
|
new_item = {
|
||||||
|
LlamaForCausalLM:
|
||||||
|
ModulePolicyDescription(attribute_replacement={},
|
||||||
|
param_replacement=[],
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(suffix="lm_head",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(gather_output=True))
|
||||||
|
])
|
||||||
|
}
|
||||||
|
policy.update(new_item)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||||
|
|
||||||
@staticmethod
|
def module_policy(self):
|
||||||
def argument(config, world_size):
|
policy = super().module_policy()
|
||||||
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
|
|
||||||
argument = {
|
# add a new item for sequence classification
|
||||||
|
new_item = {
|
||||||
LlamaForSequenceClassification:
|
LlamaForSequenceClassification:
|
||||||
Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
|
ModulePolicyDescription(attribute_replacement={},
|
||||||
|
param_replacement=[],
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(suffix="score",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(gather_output=True))
|
||||||
|
])
|
||||||
}
|
}
|
||||||
argument.update(llamapolicy)
|
policy.update(new_item)
|
||||||
|
return policy
|
||||||
@staticmethod
|
|
||||||
def score() -> List:
|
|
||||||
return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
|
|
||||||
|
Loading…
Reference in new issue