diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index d32cc58d4..429bf2175 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -12,17 +12,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate
 
-    def _pre_context_exec(self):
-        """ 
-        The Callback function when entering the context
-        """
-        pass
-
-    def _post_context_exec(self):
-        """The callback function when exiting context.
-        """
-        pass
-
     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
diff --git a/colossalai/utils/model/pipelinable.py b/colossalai/utils/model/pipelinable.py
new file mode 100644
index 000000000..ba5bbddb3
--- /dev/null
+++ b/colossalai/utils/model/pipelinable.py
@@ -0,0 +1,211 @@
+import torch
+
+from colossalai.builder.pipeline import partition_uniform, partition_balanced
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
+
+
+class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
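+    """
+    A context manager that traces module construction into ``LayerSpec`` objects,
+    replacing real parameters with tiny placeholders, so that the model can later
+    be partitioned into pipeline stages and rebuilt lazily per stage.
+    """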
+
+    def __init__(self):
+        super().__init__()
+        self._layer_spec_dict = {}
+        self._root_children = None
+        self._model = None
+        self._layer_spec_list = []
+        self._func_dict = {}
+        self._policy = "balanced"
+
+    @property
+    def policy(self):
+        return self._policy
+
+    @property
+    def layers_count(self):
+        return len(self._layer_spec_list)
+
+    @property
+    def funcs_count(self):
+        return len(self._func_dict)
+
+    def _pre_context_exec(self):
+        """ 
+        The Callback function when entering the context
+        """
+
+        # reserve rng states
+        self.cpu_rng_state = torch.get_rng_state()
+        self.cuda_rng_state = torch.cuda.get_rng_state()
+
+    def _post_context_exec(self):
+        """
+        The callback function called when exiting the context.
+        """
+
+        # restore the RNG states saved on entry
+        torch.set_rng_state(self.cpu_rng_state)
+        torch.cuda.set_rng_state(self.cuda_rng_state)
+
+    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE: the same module may be passed to this function multiple times.
+        """
+        module_id = id(module)
+        modified_args = []
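+        # Positional arguments that are modules were themselves intercepted by
+        # this context, so substitute the recorded LayerSpec for each of them.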
+        for obj in args:
+            if isinstance(obj, torch.nn.Module):
+                obj = self._layer_spec_dict[id(obj)]
+            modified_args.append(obj)
+        # (lyl) TODO: analyse kwargs as well
+        modified_args = tuple(modified_args)
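+        # Each call overwrites these; the outermost module is constructed last,
+        # so after the context exits they describe the root model.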
+        self._root_children = list(module.children())
+        self._model = module
+        layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs)
+        layer_spec.set_children(module.children())
+        self._layer_spec_dict[module_id] = layer_spec
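+        # Replace the parameters with tiny placeholders so that tracing does not
+        # hold full-size tensors; real parameters are created in LayerSpec.build().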
+        for param in module.parameters(recurse=False):
+            param.data = torch.rand(1, 1)
+
+    def to_layer_list(self, exec_seq=None):
+        """
+        Create a layer spec list and a func dict following the execution sequence given by the user.
+        If ``exec_seq`` is None, the module initialization order is taken as the execution order.
+        """
+        if exec_seq is None:
+            # If the user does not provide an execution sequence, use the initialization order as the execution order.
+            for child in self._root_children:
+                layer_spec = self._layer_spec_dict[id(child)]
+                if layer_spec.typename in (torch.nn.modules.container.ModuleList,
+                                           torch.nn.modules.container.Sequential):
+                    for child_in_container in layer_spec.children:
+                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
+
+                else:
+                    self._layer_spec_list.append(layer_spec)
+
+        else:
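+            # Functions listed before the first named module are keyed by
+            # "first"; later functions attach to the most recently seen layer.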
+            func_key = "first"
+            for index, element in enumerate(exec_seq):
+                if isinstance(element, str):
+                    module = dict(self._model.named_modules())[element]
+                    layer_spec = self._layer_spec_dict[id(module)]
+                    func_key = layer_spec
+                    self._layer_spec_list.append(layer_spec)
+                else:
+                    if func_key not in self._func_dict:
+                        self._func_dict[func_key] = []
+                    self._func_dict[func_key].append(element)
+
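+    # A minimal usage sketch (module names are from the MLP in the test below).
+    # Strings in ``exec_seq`` name submodules; any other entry is a function
+    # applied to the output of the preceding layer:
+    #
+    #     pipelinable = PipelinableContext()
+    #     with pipelinable:
+    #         model = MLP()
+    #     pipelinable.to_layer_list(['dense_1', 'activation', 'dense_2', 'dropout'])
+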
+    def partition(self, num_chunks, pipeline_size, rank):
+        """
+        The partitioned model will be built with respect to the partition policy.
+        The real module instance will be built in this method.
+        """
+        if isinstance(self._policy, str):
+            if self._policy == "uniform":
+                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
+            elif self._policy == "balanced":
+                param_counts = []
+                for layer_spec in self._layer_spec_list:
+                    param_counts.append(layer_spec.count_params())
+                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
+            else:
+                raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
+        elif isinstance(self._policy, dict):
+            parts = self._policy[rank]
+        else:
+            raise ValueError("A partition policy should be either a string or a dictionary.")
+
+        layers_to_build = []
+        for start, end in parts:
+            layers_to_build += self._layer_spec_list[start:end]
+        func_dict_in_partition = {}
+        module_list_in_partition = []
+        if rank == 0 and "first" in self._func_dict:
+            func_dict_in_partition["first"] = self._func_dict["first"]
+        for layer in layers_to_build:
+            module = layer.build()
+            module_list_in_partition.append(module)
+            if layer in self._func_dict:
+                func_dict_in_partition[id(module)] = self._func_dict[layer]
+        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
+        pipeline_model = PipelinableModel(module_list_in_partition, func_dict_in_partition)
+
+        return pipeline_model
+
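+    # A minimal sketch (assuming two pipeline stages and one chunk per stage;
+    # in real use ``rank`` comes from the pipeline parallel group):
+    #
+    #     stage_module = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)
+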
+    def load_policy(self, policy):
+        self._policy = policy
+
+
+class PipelinableModel(torch.nn.Module):
+
+    def __init__(self, module_list, func_dict):
+        super().__init__()
+        self._module_list = module_list
+        self._func_dict = func_dict
+
+    def forward(self, input_tensor):
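+        # Apply any functions registered under "first" before the modules, then
+        # run each module followed by the functions attached to it.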
+        if "first" in self._func_dict:
+            funcs = self._func_dict["first"]
+            if isinstance(funcs, list):
+                for f in funcs:
+                    input_tensor = f(input_tensor)
+            else:
+                input_tensor = funcs(input_tensor)
+
+        for module in self._module_list:
+            input_tensor = module(input_tensor)
+            if id(module) in self._func_dict:
+                funcs = self._func_dict[id(module)]
+                if isinstance(funcs, list):
+                    for f in funcs:
+                        input_tensor = f(input_tensor)
+                else:
+                    input_tensor = funcs(input_tensor)
+
+        return input_tensor
+
+
+class LayerSpec:
+
+    def __init__(self, typename, *module_args, **module_kwargs):
+        self.typename = typename
+        self.module_args = module_args
+        self.module_kwargs = module_kwargs
+        self.children = None
+        self._param_count = 0
+
+        if not issubclass(typename, torch.nn.Module):
+            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
+
+    def __repr__(self):
+        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
+
+    @property
+    def param_count(self):
+        return self._param_count
+
+    def build(self):
+        """Build the stored specification."""
+
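+        # Nested LayerSpecs recorded for submodule arguments are built first,
+        # reconstructing the real module tree recursively.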
+        recovered_args = []
+        for obj in self.module_args:
+            if isinstance(obj, LayerSpec):
+                obj = obj.build()
+            recovered_args.append(obj)
+        recovered_args = tuple(recovered_args)
+        return self.typename(*recovered_args, **self.module_kwargs)
+
+    def set_children(self, children):
+        self.children = children
+
+    def count_params(self):
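+        # NOTE: this materializes the module once just to count its parameters.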
+        self._param_count = 0
+        layer = self.build()
+        for param in layer.parameters():
+            self._param_count += param.numel()
+        return self._param_count
+
+    def reset_param_count(self):
+        self._param_count = 0
diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py
index ced1365ec..50a75a110 100644
--- a/colossalai/utils/model/utils.py
+++ b/colossalai/utils/model/utils.py
@@ -9,6 +9,28 @@ def _substitute_init_recursively(cls, func):
         func(subcls)
 
 
+def call_to_str(base, *args, **kwargs):
+    """Construct a string representation of a call.
+
+    Args:
+        base (str): name of the call
+        args (tuple, optional): args to ``base``
+        kwargs (dict, optional): kwargs supplied to ``base``
+
+    Returns:
+        str: A string representation of base(*args, **kwargs)
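+
+    Example:
+        >>> call_to_str('Linear', 16, 32, bias=False)
+        'Linear(16, 32, bias=False)'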
+    """
+    name = f'{base}('
+    if args:
+        name += ', '.join(repr(arg) for arg in args)
+        if kwargs:
+            name += ', '
+    if kwargs:
+        name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
+    name += ')'
+    return name
+
+
 class InsertPostInitMethodToModuleSubClasses(object):
 
     def __init__(self, default_dtype: Optional[torch.dtype] = None):
@@ -28,7 +50,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             @functools.wraps(f)
             def wrapper(module: torch.nn.Module, *args, **kwargs):
                 f(module, *args, **kwargs)
-                self._post_init_method(module)
+                self._post_init_method(module, *args, **kwargs)
 
             return wrapper
 
@@ -71,7 +93,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             return False
 
     # To be implemented by inheriting classes
-    def _post_init_method(self, module):
+    def _post_init_method(self, module, *args, **kwargs):
         pass
 
     def _pre_context_exec(self):
diff --git a/tests/test_utils/test_pipelinable.py b/tests/test_utils/test_pipelinable.py
new file mode 100644
index 000000000..2be3b264c
--- /dev/null
+++ b/tests/test_utils/test_pipelinable.py
@@ -0,0 +1,64 @@
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.utils.model.pipelinable import PipelinableContext
+from colossalai.testing import rerun_on_exception
+
+NUM_CHUNKS = 1
+PIPELINE_SIZE = 2
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int = 256):
+        super().__init__()
+        intermediate_dim = dim * 4
+        self.dense_1 = torch.nn.Linear(dim, intermediate_dim)
+        self.activation = torch.nn.GELU()
+        self.dense_2 = torch.nn.Linear(intermediate_dim, dim)
+        self.dropout = torch.nn.Dropout(0.1)
+
+    def forward(self, x):
+        x = self.dense_1(x)
+        x = self.activation(x)
+        x = self.dense_2(x)
+        x = self.dropout(x)
+        return x
+
+
+def run_pipelinable(rank):
+    pipelinable = PipelinableContext()
+    with pipelinable:
+        model = MLP()
+
+    assert pipelinable.policy == "balanced"
+    pipelinable.load_policy("uniform")
+    assert pipelinable.policy == "uniform"
+    pipelinable.to_layer_list()
+
+    assert pipelinable.layers_count == len(list(model.children()))
+
+    pipeline_model_part_0 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 0)
+    assert isinstance(pipeline_model_part_0, torch.nn.Module)
+    pipeline_model_part_1 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 1)
+    assert isinstance(pipeline_model_part_1, torch.nn.Module)
+
+    layers_count_in_part_0 = len(list(pipeline_model_part_0._module_list))
+    layers_count_in_part_1 = len(list(pipeline_model_part_1._module_list))
+
+    assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
+
+
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_pipelinable():
+    mp.spawn(run_pipelinable, nprocs=1)
+
+
+if __name__ == '__main__':
+    test_pipelinable()