diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index bb9a22b41..292a6e5ff 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -1,5 +1,3 @@ -from functools import partial - from torch.nn import Parameter from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm @@ -13,8 +11,6 @@ from colossalai.inference.modeling.models.nopadding_llama import ( ) from colossalai.inference.utils import init_to_get_rotary from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription - -# import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -45,27 +41,18 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): ] ) - self.shard_config._infer() - - infer_forward = llama_causal_lm_forward - method_replacement = {"forward": partial(infer_forward)} self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaForCausalLM + description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM ) - - infer_forward = llama_model_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - - infer_forward = llama_decoder_layer_forward - method_replacement = {"forward": partial(infer_forward)} self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel + ) + self.append_or_create_method_replacement( + description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer + ) + self.append_or_create_method_replacement( + description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm ) - - infer_forward = llama_rmsnorm_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm) return policy diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 415fc6dd5..ad79394a9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -36,8 +36,6 @@ class ShardConfig: enable_sequence_overlap: bool = False parallel_output = True extra_kwargs: Dict[str, Any] = field(default_factory=dict) - # pipeline_parallel_size: int - # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] @property @@ -70,9 +68,3 @@ class ShardConfig: self.enable_jit_fused = True self.enable_sequence_parallelism = True self.enable_sequence_overlap = True - - def _infer(self): - """ - Set default params for inference. - """ - # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"