mirror of https://github.com/hpcaitech/ColossalAI
63 lines
2.3 KiB
Python
63 lines
2.3 KiB
Python
![]() |
import torch.nn as nn
|
||
|
from torch.nn import Parameter
|
||
|
|
||
|
from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
|
||
|
from colossalai.inference.modeling.models.nopadding_llama import (
|
||
|
llama_causal_lm_forward,
|
||
|
llama_decoder_layer_forward,
|
||
|
llama_model_forward,
|
||
|
llama_rmsnorm_forward,
|
||
|
)
|
||
|
from colossalai.inference.utils import init_to_get_rotary
|
||
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||
|
|
||
|
|
||
|
class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
|
||
|
def __init__(self) -> None:
|
||
|
super().__init__()
|
||
|
|
||
|
def module_policy(self):
|
||
|
policy = super().module_policy()
|
||
|
|
||
|
decoder_attribute_replacement = {
|
||
|
"lm_head.weight": Parameter(
|
||
|
nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
|
||
|
),
|
||
|
}
|
||
|
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
|
||
|
attribute_replacement=decoder_attribute_replacement,
|
||
|
)
|
||
|
|
||
|
policy["DecoderLayer"] = ModulePolicyDescription(
|
||
|
sub_module_replacement=[
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix="mlp",
|
||
|
target_module=NopadBaichuanMLP,
|
||
|
),
|
||
|
SubModuleReplacementDescription(
|
||
|
suffix="self_attn",
|
||
|
target_module=NopadBaichuanAttention,
|
||
|
),
|
||
|
]
|
||
|
)
|
||
|
|
||
|
self.append_or_create_method_replacement(
|
||
|
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
|
||
|
)
|
||
|
self.append_or_create_method_replacement(
|
||
|
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
|
||
|
)
|
||
|
self.append_or_create_method_replacement(
|
||
|
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
|
||
|
)
|
||
|
self.append_or_create_method_replacement(
|
||
|
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
|
||
|
)
|
||
|
|
||
|
return policy
|
||
|
|
||
|
def postprocess(self):
|
||
|
init_to_get_rotary(self.model.model)
|
||
|
return self.model
|