diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 9e370d733..98be1be48 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -114,16 +114,28 @@ class MetaInfoProp(torch.fx.Interpreter):
             return TensorMetadata(None, None, False, None, 0, False)
 
         meta = _map_aggregate(result, extract_tensor_meta)
-        n.meta['tensor_meta'] = meta
-        total_node_size = _compute_node_numel(n.meta['tensor_meta'])
-        # counting the total size of parameters
+        n.meta['tensor_meta'] = meta
+
+        # get byte size for each element
+        size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size()
+
+        # compute the total size of activation tensors
+        total_activation_size = _compute_node_numel(n.meta['tensor_meta'])
+
+        # compute the total size of model parameters
         total_param_size = 0
         if n.op == 'call_module':
             target_module = n.graph.owning_module.get_submodule(n.target)
             for param in target_module.parameters():
                 total_param_size += param.numel()
-        total_node_size += total_param_size
-        n.node_size = total_node_size
+        # compute the total memory cost of activation tensors and model parameters
+        total_activation_size *= size_per_elem_bytes
+        total_param_size *= size_per_elem_bytes
+
+        # TODO: node.node_size is not an original attribute
+        setattr(n, 'node_size', total_activation_size + total_param_size)
+        setattr(n, 'param_size', total_param_size)
+        setattr(n, 'activation_size', total_activation_size)
 
         n.meta['type'] = type(result)
         return result
diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py
index 84cef23b0..1da4f6b3b 100644
--- a/tests/test_fx/test_meta_info_prop.py
+++ b/tests/test_fx/test_meta_info_prop.py
@@ -23,12 +23,24 @@ def test_meta_info_prop():
     input_sample = torch.rand(BATCH_SIZE, DIM_IN)
     orig_output = model(input_sample)
     gm = symbolic_trace(model)
+    for node in gm.graph.nodes:
+        assert not hasattr(node,
+                           'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
+        assert not hasattr(node,
+                           'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure'
+        assert not hasattr(
+            node,
+            'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure'
     MetaInfoProp(gm).run(input_sample)
     for node in gm.graph.nodes:
         if node.op == 'placeholder':
             meta_check(node.meta['tensor_meta'], input_sample)
         if node.op == 'output':
             meta_check(node.meta['tensor_meta'], orig_output)
+        assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
+        assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure'
+        assert hasattr(
+            node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure'
 
 
 if __name__ == '__main__':