diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 646b3cede..96c2b15ee 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -64,7 +64,7 @@ class InferenceEngine:
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
-        model_policy: Policy = None,
+        model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
         self.inference_config = inference_config
         self.dtype = inference_config.dtype
@@ -105,7 +105,7 @@ class InferenceEngine:
 
         self._verify_args()
 
-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
         """
         Shard model or/and Load weight
 
@@ -150,11 +150,17 @@ class InferenceEngine:
             )
 
         if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
+            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+            model_policy = model_policy_map.get(model_policy_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)