mirror of https://github.com/hpcaitech/ColossalAI
[fx] update MetaInforProp pass to process more complex node.meta (#1344)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx] update MetaInforProp pass to process more complex node.meta
pull/1350/head
parent
7a8702c06d
commit
051592c64e
|
@ -33,6 +33,24 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
||||||
return TensorMetadata(shape, dtype, requires_grad, stride, numel)
|
return TensorMetadata(shape, dtype, requires_grad, stride, numel)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_node_numel(node_metadata: any) -> int:
|
||||||
|
"""
|
||||||
|
Compute numel of a node with ``tensor_meta`` attribute.
|
||||||
|
"""
|
||||||
|
node_numel = 0
|
||||||
|
|
||||||
|
if isinstance(node_metadata, TensorMetadata):
|
||||||
|
node_numel += node_metadata.numel
|
||||||
|
elif isinstance(node_metadata, dict):
|
||||||
|
value_list = [v for _, v in node_metadata.items()]
|
||||||
|
node_numel += _compute_node_numel(value_list)
|
||||||
|
else:
|
||||||
|
for element in node_metadata:
|
||||||
|
node_numel += _compute_node_numel(element)
|
||||||
|
|
||||||
|
return node_numel
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
class MetaInfoProp(torch.fx.Interpreter):
|
class MetaInfoProp(torch.fx.Interpreter):
|
||||||
"""
|
"""
|
||||||
|
@ -78,20 +96,13 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
meta = map_aggregate(result, extract_tensor_meta)
|
meta = map_aggregate(result, extract_tensor_meta)
|
||||||
|
|
||||||
if found_tensor:
|
if found_tensor:
|
||||||
n.meta['tensor_meta'] = meta
|
n.meta['tensor_meta'] = meta
|
||||||
else:
|
else:
|
||||||
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
|
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
|
||||||
# counting the total size of node outputs
|
# counting the total size of node outputs
|
||||||
total_node_size = 0
|
total_node_size = _compute_node_numel(n.meta['tensor_meta'])
|
||||||
if isinstance(n.meta['tensor_meta'], TensorMetadata):
|
|
||||||
total_node_size += n.meta['tensor_meta'].numel
|
|
||||||
else:
|
|
||||||
for element in n.meta['tensor_meta']:
|
|
||||||
assert isinstance(
|
|
||||||
element, TensorMetadata
|
|
||||||
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
|
|
||||||
total_node_size += element.numel
|
|
||||||
# counting the total size of parameters
|
# counting the total size of parameters
|
||||||
total_param_size = 0
|
total_param_size = 0
|
||||||
if n.op == 'call_module':
|
if n.op == 'call_module':
|
||||||
|
|
Loading…
Reference in New Issue