2022-09-14 01:36:43 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from enum import Enum
|
|
|
|
from typing import Dict
|
|
|
|
from torch.fx import Graph, Node
|
2022-09-15 06:46:36 +00:00
|
|
|
from .memory import activation_size, is_inplace
|
2022-09-14 01:36:43 +00:00
|
|
|
|
|
|
|
|
2022-09-14 06:27:04 +00:00
|
|
|
class Phase(Enum):
    """Execution phase tag stored under ``node.meta['phase']`` for every graph node."""

    # Nodes that run in the forward pass.
    FORWARD = 0
    # Nodes that run in the backward pass.
    BACKWARD = 1
    # Graph input placeholders.
    PLACEHOLDER = 2
|
2022-09-14 01:36:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GraphInfo:
    r"""
    GraphInfo is a dataclass for MetaInfo, which measures
    the execution memory cost and FLOPs with `MetaTensor`.
    The dataflow analysis is conducted on a single node of the FX graph.

    ============================================================================
                            -------------------------------
                            |            Node             |
    [fwd_in] are     --->   | [fwd_in]       [bwd_out]    |    <----- [bwd_out] marks the memory for `grad_out`.
    placeholders saved for  |     | \__________     |     |
    backward.               |     |            \    |     |
                            | [fwd_tmp] ------> [bwd_tmp] |    <-----
                            |     |  \_________     |     |    [bwd_tmp] marks the peak memory
                            |    / \           \    |     |    in the backward pass.
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
    in [fwd_tmp] because    |          |  \_____    |     |
    it is not saved for     |          |        \   |     |
    backward.               |      [fwd_out]     \  |     |    <----- [fwd_out] is [fwd_in] for the next node.
                            -------------------------------
    ============================================================================

    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node.
        bwd_flop (int): The backward FLOPs of a certain node.
        save_fwd_in (bool): Decision variable: whether to save the fwd_mem_out of parent nodes.
        fwd_mem_tmp (int): See [fwd_tmp] in the illustration above.
        fwd_mem_out (int): See [fwd_out] in the illustration above.
        bwd_mem_tmp (int): See [bwd_tmp] in the illustration above.
        bwd_mem_out (int): See [bwd_out] in the illustration above.
    """

    # Compute cost.
    fwd_flop: int = 0
    bwd_flop: int = 0

    # Whether the node's inputs (outputs of its parents) must be kept for backward.
    save_fwd_in: bool = False

    # Forward-pass memory.
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0

    # Backward-pass memory.
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0
|
|
|
|
|
|
|
|
|
2022-09-14 06:27:04 +00:00
|
|
|
def is_phase(n: Node, phase: Phase) -> bool:
    """Check whether node ``n`` has been tagged with ``phase``.

    Args:
        n (Node): An FX node whose ``meta`` must contain a ``'phase'`` entry.
        phase (Phase): The phase to compare the node's tag against.

    Returns:
        bool: The result of ``n.meta['phase'] == phase``.

    Raises:
        AssertionError: If ``n.meta`` has no ``'phase'`` key (when assertions are enabled).
    """
    node_meta = n.meta
    assert 'phase' in node_meta, f'Node meta of {n} has no key `phase`!'
    return node_meta['phase'] == phase
|
2022-09-14 01:36:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def is_saved(n: Node) -> bool:
    """Return whether node ``n`` has anything recorded under ``meta['saved_tensor']``.

    Args:
        n (Node): An FX node whose ``meta`` contains a sized ``'saved_tensor'`` entry.

    Returns:
        bool: ``True`` if ``n.meta['saved_tensor']`` is non-empty, else ``False``.
    """
    # Return a real bool rather than the raw length: truthiness is identical for
    # callers using this as a predicate, and it matches the ``is_*`` naming.
    return len(n.meta['saved_tensor']) > 0
|
2022-09-14 01:36:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    r"""Analyze the autograd node dependencies and find out the memory usage.

    Every node of the input graph must be marked with a ``phase`` entry in
    ``node.meta`` (see ``is_phase``) and must carry a ``saved_tensor`` entry in
    ``node.meta``; the activation size of that entry is what gets accounted.

    ============================================================================
    Placeholder ---->   p           o     <---- We need to keep track of grad out
                        |\________  |
                        ↓         ↘ |
                        f --------> b
                        |\ \_____   ↑
                        | \      ↘ /
                        f  f ----> b      <---- Not every forward result needs to be saved for backward
                        |   \____  ↑
                         ↘      ↘ |
                           f ----> b      <---- Backward can be freed as soon as it is no longer required.
                             ↘ ↗
                               l
    =============================================================================

    Only ``save_fwd_in``, ``fwd_mem_tmp``, ``bwd_mem_tmp`` and ``bwd_mem_out`` of
    the returned GraphInfo are filled in here; the remaining fields keep their
    defaults.

    Args:
        graph (Graph): The autograd graph with nodes marked for keyword `phase`.

    Returns:
        graph_info (GraphInfo): Meta information for the dataflow.
    """

    def _peak_memory(deps: Dict[Node, int]):
        # Net memory at the current point of the backward pass: backward tensors
        # that still have pending consumers (ignoring in-place ones) count as
        # live, while forward tensors whose consumers have all run (count set to
        # -inf) are released.
        live_bytes = 0
        for node, pending in deps.items():
            bwd_still_needed = (pending > 0 and is_phase(node, Phase.BACKWARD)
                                and not all(map(is_inplace, node.users)) and not is_inplace(node))
            if bwd_still_needed:
                live_bytes += activation_size(node.meta['saved_tensor'])
            fwd_released = pending <= float('-inf') and is_phase(node, Phase.FORWARD)
            if fwd_released:
                live_bytes -= activation_size(node.meta['saved_tensor'])
        return live_bytes

    # Number of not-yet-visited consumers of every node seen so far; once the
    # last consumer has been visited the entry is pinned to -inf ("freed").
    remaining_users = {}
    info = GraphInfo()

    for node in graph.nodes:
        remaining_users[node] = len(node.users)

        # A forward tensor marked `save` that is also an input to `Phase.FORWARD`
        # must be saved during forward. Placeholder tensors belong to
        # `fwd_mem_in`, which stays in memory even if this function is
        # checkpointed; every other forward tensor belongs to `fwd_mem_tmp`,
        # which can be freed when the node is checkpointed.
        if is_phase(node, Phase.PLACEHOLDER):
            info.save_fwd_in |= activation_size(node.meta['saved_tensor']) > 0
        elif is_phase(node, Phase.FORWARD):
            info.fwd_mem_tmp += activation_size(node.meta['saved_tensor'])
        elif is_phase(node, Phase.BACKWARD):
            if node.users:
                info.bwd_mem_tmp = max(info.bwd_mem_tmp, _peak_memory(remaining_users))
            else:
                # TODO: some of the bwd_mem_out might be model parameters.
                # A backward node without users is a `grad_out` node.
                info.bwd_mem_out += activation_size(node.meta['saved_tensor'])

        # This node has now consumed its inputs; release any whose last
        # consumer it was.
        for parent in node.all_input_nodes:
            if parent not in remaining_users:
                continue
            remaining_users[parent] -= 1
            if remaining_users[parent] <= 0:
                remaining_users[parent] = float('-inf')

    return info
|