def _compute_node_numel(node_metadata: object) -> int:
    """
    Compute the total numel recorded in a node's ``tensor_meta`` value.

    The value may be a single ``TensorMetadata``, or an arbitrarily nested
    combination of lists, tuples and dicts whose leaves are ``TensorMetadata``
    or non-tensor objects (for example ``None`` or a plain ``int`` that the
    node returned alongside its tensors).

    Args:
        node_metadata: the object stored under ``n.meta['tensor_meta']``.

    Returns:
        The sum of ``numel`` over every ``TensorMetadata`` leaf. Non-tensor
        leaves contribute ``0``; an empty container yields ``0``.
    """
    # TensorMetadata is tested before the generic containers.
    # NOTE(review): presumably it is a NamedTuple (i.e. a tuple subclass), in
    # which case testing ``tuple`` first would iterate its fields -- confirm.
    if isinstance(node_metadata, TensorMetadata):
        return node_metadata.numel

    # Only dict values can hold tensor metadata; keys are ignored.
    if isinstance(node_metadata, dict):
        return sum(_compute_node_numel(value) for value in node_metadata.values())

    if isinstance(node_metadata, (list, tuple)):
        return sum(_compute_node_numel(element) for element in node_metadata)

    # Any other leaf (None, numbers, strings, slices, ...) carries no tensor
    # data. Blindly iterating such a leaf would raise ``TypeError`` for
    # non-iterables and recurse without bound for strings, because a
    # one-character string iterates to itself.
    return 0