mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix MetaInfoProp for incorrect calculations and add detections for inplace op. (#1466)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] add rules to linearize computation graphs for searching. (#2) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies. * [fx] fix MetaInfoProp. * [fx] fix MetaInfoProp. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands.pull/1467/head
parent
e7383f578b
commit
bbc58d881b
|
@ -36,20 +36,20 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
|||
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
|
||||
|
||||
|
||||
def _compute_node_numel(node_metadata: any) -> int:
|
||||
def _compute_activation_size(node_metadata: any) -> int:
|
||||
"""
|
||||
Compute numel of a node with ``tensor_meta`` attribute.
|
||||
"""
|
||||
node_numel = 0
|
||||
|
||||
if isinstance(node_metadata, TensorMetadata):
|
||||
node_numel += node_metadata.numel
|
||||
node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
|
||||
elif isinstance(node_metadata, dict):
|
||||
value_list = [v for _, v in node_metadata.items()]
|
||||
node_numel += _compute_node_numel(value_list)
|
||||
node_numel += _compute_activation_size(value_list)
|
||||
else:
|
||||
for element in node_metadata:
|
||||
node_numel += _compute_node_numel(element)
|
||||
node_numel += _compute_activation_size(element)
|
||||
|
||||
return node_numel
|
||||
|
||||
|
@ -105,6 +105,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
"""
|
||||
|
||||
def run_node(self, n: Node) -> Any:
|
||||
# TODO: We might run_node(n) with meta data, and count FLOPS for each node
|
||||
result = super().run_node(n)
|
||||
|
||||
def extract_tensor_meta(obj):
|
||||
|
@ -116,24 +117,20 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
meta = _map_aggregate(result, extract_tensor_meta)
|
||||
n.meta['tensor_meta'] = meta
|
||||
|
||||
# get byte size for each element
|
||||
size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size()
|
||||
|
||||
# compute the total size of activation tensors
|
||||
total_activation_size = _compute_node_numel(n.meta['tensor_meta'])
|
||||
|
||||
# compute the total size of model parameters
|
||||
total_activation_size = 0
|
||||
total_param_size = 0
|
||||
if n.op == 'call_module':
|
||||
target_module = n.graph.owning_module.get_submodule(n.target)
|
||||
if not getattr(target_module, 'inplace', False):
|
||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
||||
for param in target_module.parameters():
|
||||
total_param_size += param.numel()
|
||||
total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
|
||||
elif n.op == 'call_function':
|
||||
if 'inplace' not in n.kwargs:
|
||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
||||
else:
|
||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
||||
|
||||
# compute the total memory cost of activation tensors and model parameters
|
||||
total_activation_size *= size_per_elem_bytes
|
||||
total_param_size *= size_per_elem_bytes
|
||||
|
||||
# TODO: node.node_size is not an original attribute
|
||||
setattr(n, 'node_size', total_activation_size + total_param_size)
|
||||
setattr(n, 'param_size', total_param_size)
|
||||
setattr(n, 'activation_size', total_activation_size)
|
||||
|
|
Loading…
Reference in New Issue