@@ -1,5 +1,3 @@
-from functools import partial
-
 from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
@@ -13,8 +11,6 @@ from colossalai.inference.modeling.models.nopadding_llama import (
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
-
-# import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 
@@ -45,27 +41,18 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             ]
         )
 
-        self.shard_config._infer()
-
-        infer_forward = llama_causal_lm_forward
-        method_replacement = {"forward": partial(infer_forward)}
         self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaForCausalLM
+            description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
         )
-
-        infer_forward = llama_model_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
-
-        infer_forward = llama_decoder_layer_forward
-        method_replacement = {"forward": partial(infer_forward)}
         self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+            description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel
+        )
+        self.append_or_create_method_replacement(
+            description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer
+        )
+        self.append_or_create_method_replacement(
+            description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm
         )
-
-        infer_forward = llama_rmsnorm_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
 
         return policy
 