[fx] fix incorrect calculations in MetaInfoProp and add detection for inplace ops. (#1466)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usage

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix a misinterpretation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] add rules to linearize computation graphs for searching. (#2)

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.

* [fx] remove chen_sqrt for the sake of simplicity

* [fx] fix inconsistencies.

* [fx] fix MetaInfoProp.

* [fx] consider inplace operations in MetaInfoProp.

Super Daniel 2022-08-18 11:27:06 +08:00 committed by GitHub
parent e7383f578b
commit bbc58d881b
1 changed file with 14 additions and 17 deletions

@@ -36,20 +36,20 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
     return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
 
-def _compute_node_numel(node_metadata: any) -> int:
+def _compute_activation_size(node_metadata: any) -> int:
     """
     Compute numel of a node with ``tensor_meta`` attribute.
     """
     node_numel = 0
     if isinstance(node_metadata, TensorMetadata):
-        node_numel += node_metadata.numel
+        node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
     elif isinstance(node_metadata, dict):
         value_list = [v for _, v in node_metadata.items()]
-        node_numel += _compute_node_numel(value_list)
+        node_numel += _compute_activation_size(value_list)
     else:
         for element in node_metadata:
-            node_numel += _compute_node_numel(element)
+            node_numel += _compute_activation_size(element)
     return node_numel
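
Note on the hunk above: the helper is renamed from _compute_node_numel to _compute_activation_size because it now returns bytes rather than an element count; each tensor's contribution is numel multiplied by element_size() for its dtype. A minimal standalone sketch of that arithmetic (the shape and dtype below are illustrative values, not taken from this commit):

    import torch

    # element_size() gives bytes per element for a dtype: 4 for float32, 2 for float16, etc.
    bytes_per_elem = torch.tensor([], dtype=torch.float32).element_size()  # -> 4

    # hypothetical activation of shape (64, 3, 224, 224) in float32
    numel = 64 * 3 * 224 * 224                      # 9,633,792 elements
    activation_bytes = numel * bytes_per_elem       # 38,535,168 bytes, roughly 36.8 MiB
    print(activation_bytes)
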
@@ -105,6 +105,7 @@ class MetaInfoProp(torch.fx.Interpreter):
     """
 
     def run_node(self, n: Node) -> Any:
+        # TODO: We might run_node(n) with meta data, and count FLOPS for each node
        result = super().run_node(n)
 
        def extract_tensor_meta(obj):
@@ -116,24 +117,20 @@ class MetaInfoProp(torch.fx.Interpreter):
         meta = _map_aggregate(result, extract_tensor_meta)
         n.meta['tensor_meta'] = meta
 
-        # get byte size for each element
-        size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size()
-        # compute the total size of activation tensors
-        total_activation_size = _compute_node_numel(n.meta['tensor_meta'])
-        # compute the total size of model parameters
+        total_activation_size = 0
         total_param_size = 0
         if n.op == 'call_module':
             target_module = n.graph.owning_module.get_submodule(n.target)
+            if not getattr(target_module, 'inplace', False):
+                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
             for param in target_module.parameters():
-                total_param_size += param.numel()
+                total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
+        elif n.op == 'call_function':
+            if 'inplace' not in n.kwargs:
+                total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
+        else:
+            total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
 
-        # compute the total memory cost of activation tensors and model parameters
-        total_activation_size *= size_per_elem_bytes
-        total_param_size *= size_per_elem_bytes
-
         # TODO: node.node_size is not an original attribute
         setattr(n, 'node_size', total_activation_size + total_param_size)
         setattr(n, 'param_size', total_param_size)
         setattr(n, 'activation_size', total_activation_size)
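
With the hunk above, activations are only counted when a node actually allocates a new tensor: for call_module nodes the pass skips activation accounting when the submodule carries inplace=True (e.g. nn.ReLU(inplace=True) overwrites its input), for call_function nodes it skips it when the call passes an inplace keyword argument, and every other node is counted as before. Parameter sizes are likewise converted to bytes using each parameter's dtype. A hedged usage sketch of the pass after this commit — the import path, model, and input shape are assumptions for illustration, not taken from the diff:

    import torch
    import torchvision.models as tm
    from torch.fx import symbolic_trace
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp  # assumed import path

    model = tm.resnet18()
    gm = symbolic_trace(model)

    # MetaInfoProp is an fx Interpreter, so run() executes the graph once on a sample
    # input and run_node() attaches the size attributes set by this commit to each node.
    MetaInfoProp(gm).run(torch.rand(2, 3, 224, 224))

    for node in gm.graph.nodes:
        # node_size = activation_size + param_size, all reported in bytes
        print(node.name, node.node_size, node.activation_size, node.param_size)

These per-node byte counts are what the activation checkpointing solvers mentioned in the commit list above build on when searching for checkpoints.
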