|
|
@ -1,5 +1,3 @@
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch.nn import Parameter
|
|
|
|
from torch.nn import Parameter
|
|
|
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
|
|
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
|
|
|
|
|
|
|
|
|
|
@ -13,8 +11,6 @@ from colossalai.inference.modeling.models.nopadding_llama import (
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from colossalai.inference.utils import init_to_get_rotary
|
|
|
|
from colossalai.inference.utils import init_to_get_rotary
|
|
|
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
|
|
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
|
|
|
|
|
|
|
|
|
|
|
# import colossalai
|
|
|
|
|
|
|
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
|
|
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -45,27 +41,18 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
|
|
|
]
|
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
self.shard_config._infer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
infer_forward = llama_causal_lm_forward
|
|
|
|
|
|
|
|
method_replacement = {"forward": partial(infer_forward)}
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
description=method_replacement, policy=policy, target_key=LlamaForCausalLM
|
|
|
|
description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
infer_forward = llama_model_forward
|
|
|
|
|
|
|
|
method_replacement = {"forward": partial(infer_forward)}
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
infer_forward = llama_decoder_layer_forward
|
|
|
|
|
|
|
|
method_replacement = {"forward": partial(infer_forward)}
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
|
|
|
description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
|
|
|
|
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(
|
|
|
|
|
|
|
|
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
infer_forward = llama_rmsnorm_forward
|
|
|
|
|
|
|
|
method_replacement = {"forward": partial(infer_forward)}
|
|
|
|
|
|
|
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return policy
|
|
|
|
return policy
|
|
|
|
|
|
|
|
|
|
|
|