[Fix/Inference] Remove unused and non-functional functions (#5543)

* [fix] remove unused func

* rm non-functional partial
pull/5567/head
Yuanheng Zhao 8 months ago committed by GitHub
parent a2878e39f4
commit 4bb5d8923a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,5 +1,3 @@
from functools import partial
from torch.nn import Parameter from torch.nn import Parameter
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
@ -13,8 +11,6 @@ from colossalai.inference.modeling.models.nopadding_llama import (
) )
from colossalai.inference.utils import init_to_get_rotary from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@ -45,27 +41,18 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
] ]
) )
self.shard_config._infer()
infer_forward = llama_causal_lm_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaForCausalLM description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
) )
infer_forward = llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer
)
self.append_or_create_method_replacement(
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm
) )
infer_forward = llama_rmsnorm_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
return policy return policy

@ -36,8 +36,6 @@ class ShardConfig:
enable_sequence_overlap: bool = False enable_sequence_overlap: bool = False
parallel_output = True parallel_output = True
extra_kwargs: Dict[str, Any] = field(default_factory=dict) extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@property @property
@ -70,9 +68,3 @@ class ShardConfig:
self.enable_jit_fused = True self.enable_jit_fused = True
self.enable_sequence_parallelism = True self.enable_sequence_parallelism = True
self.enable_sequence_overlap = True self.enable_sequence_overlap = True
def _infer(self):
"""
Set default params for inference.
"""
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"

Loading…
Cancel
Save