[fx] update MetaInforProp pass to process more complex node.meta (#1344)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] update MetaInforProp pass to process more complex node.meta
pull/1350/head
YuliangLiu0306 2022-07-21 10:57:52 +08:00 committed by GitHub
parent 7a8702c06d
commit 051592c64e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 9 deletions

View File

@ -33,6 +33,24 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
return TensorMetadata(shape, dtype, requires_grad, stride, numel)
def _compute_node_numel(node_metadata: any) -> int:
"""
Compute numel of a node with ``tensor_meta`` attribute.
"""
node_numel = 0
if isinstance(node_metadata, TensorMetadata):
node_numel += node_metadata.numel
elif isinstance(node_metadata, dict):
value_list = [v for _, v in node_metadata.items()]
node_numel += _compute_node_numel(value_list)
else:
for element in node_metadata:
node_numel += _compute_node_numel(element)
return node_numel
@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
"""
@ -78,20 +96,13 @@ class MetaInfoProp(torch.fx.Interpreter):
return obj
meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
else:
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
# counting the total size of node outputs
total_node_size = 0
if isinstance(n.meta['tensor_meta'], TensorMetadata):
total_node_size += n.meta['tensor_meta'].numel
else:
for element in n.meta['tensor_meta']:
assert isinstance(
element, TensorMetadata
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
total_node_size += element.numel
total_node_size = _compute_node_numel(n.meta['tensor_meta'])
# counting the total size of parameters
total_param_size = 0
if n.op == 'call_module':