def _elem_size(dtype: torch.dtype) -> int:
    """Return the size in bytes of a single element of ``dtype``.

    An empty tensor carries no storage, so this is a cheap, version-safe
    way to query a dtype's per-element byte width.
    """
    return torch.tensor([], dtype=dtype).element_size()


def _compute_activation_size(node_metadata: Any) -> int:
    """Recursively compute the total byte size described by ``node_metadata``.

    Args:
        node_metadata: a ``TensorMetadata``, or an arbitrarily nested
            dict/list/tuple of them (mirroring the aggregate structure of a
            node's runtime result as produced by ``_map_aggregate``).

    Returns:
        Total size in bytes (``numel * element_size``) over every
        ``TensorMetadata`` leaf in the tree.
    """
    total_size = 0
    if isinstance(node_metadata, TensorMetadata):
        total_size += node_metadata.numel * _elem_size(node_metadata.dtype)
    elif isinstance(node_metadata, dict):
        # Only the values can hold tensor metadata; keys are plain labels.
        total_size += _compute_activation_size(list(node_metadata.values()))
    else:
        # Any other aggregate (list/tuple) is walked element by element.
        for element in node_metadata:
            total_size += _compute_activation_size(element)
    return total_size


class MetaInfoProp(torch.fx.Interpreter):
    """FX interpreter pass that annotates every node with memory statistics.

    After ``run``, each node carries ``node_size``, ``param_size`` and
    ``activation_size`` attributes (in bytes) plus a ``tensor_meta`` entry
    in ``node.meta`` describing the node's runtime output.
    """

    def run_node(self, n: Node) -> Any:
        # TODO: We might run_node(n) with meta data, and count FLOPS for each node
        result = super().run_node(n)

        def extract_tensor_meta(obj):
            # Non-tensor results (ints, Nones, ...) pass through unchanged.
            if isinstance(obj, torch.Tensor):
                return _extract_tensor_metadata(obj)
            return obj

        meta = _map_aggregate(result, extract_tensor_meta)
        n.meta['tensor_meta'] = meta

        total_activation_size = 0
        total_param_size = 0
        if n.op == 'call_module':
            target_module = n.graph.owning_module.get_submodule(n.target)
            # In-place modules (e.g. nn.ReLU(inplace=True)) reuse their input
            # buffer, so their output adds no activation memory.
            if not getattr(target_module, 'inplace', False):
                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
            for param in target_module.parameters():
                total_param_size += param.numel() * _elem_size(param.dtype)
        elif n.op == 'call_function':
            # Bug fix: the previous check tested only the *presence* of the
            # ``inplace`` kwarg, so an explicit ``inplace=False`` call was
            # wrongly counted as free.  Test the kwarg's value instead.
            if not n.kwargs.get('inplace', False):
                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
        else:
            total_activation_size = _compute_activation_size(n.meta['tensor_meta'])

        # TODO: node_size/param_size/activation_size are not original Node attributes
        setattr(n, 'node_size', total_activation_size + total_param_size)
        setattr(n, 'param_size', total_param_size)
        setattr(n, 'activation_size', total_activation_size)
        return result