mirror of https://github.com/hpcaitech/ColossalAI
[fx] update MetaInforProp pass to process more complex node.meta (#1344)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx] update MetaInforProp pass to process more complex node.meta
pull/1350/head
parent
7a8702c06d
commit
051592c64e
|
@ -33,6 +33,24 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
|||
return TensorMetadata(shape, dtype, requires_grad, stride, numel)
|
||||
|
||||
|
||||
def _compute_node_numel(node_metadata: any) -> int:
|
||||
"""
|
||||
Compute numel of a node with ``tensor_meta`` attribute.
|
||||
"""
|
||||
node_numel = 0
|
||||
|
||||
if isinstance(node_metadata, TensorMetadata):
|
||||
node_numel += node_metadata.numel
|
||||
elif isinstance(node_metadata, dict):
|
||||
value_list = [v for _, v in node_metadata.items()]
|
||||
node_numel += _compute_node_numel(value_list)
|
||||
else:
|
||||
for element in node_metadata:
|
||||
node_numel += _compute_node_numel(element)
|
||||
|
||||
return node_numel
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class MetaInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
|
@ -78,20 +96,13 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
return obj
|
||||
|
||||
meta = map_aggregate(result, extract_tensor_meta)
|
||||
|
||||
if found_tensor:
|
||||
n.meta['tensor_meta'] = meta
|
||||
else:
|
||||
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
|
||||
# counting the total size of node outputs
|
||||
total_node_size = 0
|
||||
if isinstance(n.meta['tensor_meta'], TensorMetadata):
|
||||
total_node_size += n.meta['tensor_meta'].numel
|
||||
else:
|
||||
for element in n.meta['tensor_meta']:
|
||||
assert isinstance(
|
||||
element, TensorMetadata
|
||||
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
|
||||
total_node_size += element.numel
|
||||
total_node_size = _compute_node_numel(n.meta['tensor_meta'])
|
||||
# counting the total size of parameters
|
||||
total_param_size = 0
|
||||
if n.op == 'call_module':
|
||||
|
|
Loading…
Reference in New Issue