You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/inference/modeling/policy/nopadding_llama.py

116 lines
4.6 KiB

from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.nopadding_llama import (
NopadLlamaAttention,
NopadLlamaMLP,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
else:
decoder_attribute_replacement = None
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadLlamaMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadLlamaAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
),
],
)
policy[LlamaForCausalLM] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
)
],
)
# self.shard_config._infer()
self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
)
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer
)
self.append_or_create_method_replacement(
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm
)
return policy
def postprocess(self):
init_to_get_rotary(self.model.model, self.model.config.rope_theta)
return self.model
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy":
return NoPaddingLlamaModelInferPolicy()