diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 9e370d733..98be1be48 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -114,16 +114,28 @@ class MetaInfoProp(torch.fx.Interpreter):
             return TensorMetadata(None, None, False, None, 0, False)
 
         meta = _map_aggregate(result, extract_tensor_meta)
-        n.meta['tensor_meta'] = meta
-        total_node_size = _compute_node_numel(n.meta['tensor_meta'])
-        # counting the total size of parameters
+        n.meta['tensor_meta'] = meta
+
+        # get byte size for each element
+        size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size()
+
+        # compute the total size of activation tensors
+        total_activation_size = _compute_node_numel(n.meta['tensor_meta'])
+
+        # compute the total size of model parameters
         total_param_size = 0
         if n.op == 'call_module':
             target_module = n.graph.owning_module.get_submodule(n.target)
             for param in target_module.parameters():
                 total_param_size += param.numel()
-        total_node_size += total_param_size
-        n.node_size = total_node_size
+        # compute the total memory cost of activation tensors and model parameters
+        total_activation_size *= size_per_elem_bytes
+        total_param_size *= size_per_elem_bytes
+
+        # TODO: node.node_size is not an original attribute
+        setattr(n, 'node_size', total_activation_size + total_param_size)
+        setattr(n, 'param_size', total_param_size)
+        setattr(n, 'activation_size', total_activation_size)
 
         n.meta['type'] = type(result)
         return result
diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py
index 84cef23b0..1da4f6b3b 100644
--- a/tests/test_fx/test_meta_info_prop.py
+++ b/tests/test_fx/test_meta_info_prop.py
@@ -23,12 +23,24 @@ def test_meta_info_prop():
     input_sample = torch.rand(BATCH_SIZE, DIM_IN)
     orig_output = model(input_sample)
     gm = symbolic_trace(model)
+    for node in gm.graph.nodes:
+        assert not hasattr(node,
+                           'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
+        assert not hasattr(node,
+                           'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure'
+        assert not hasattr(
+            node,
+            'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure'
     MetaInfoProp(gm).run(input_sample)
     for node in gm.graph.nodes:
         if node.op == 'placeholder':
             meta_check(node.meta['tensor_meta'], input_sample)
         if node.op == 'output':
             meta_check(node.meta['tensor_meta'], orig_output)
+        assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
+        assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure'
+        assert hasattr(
+            node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure'
 
 
 if __name__ == '__main__':