[hotfix/rotor] fix variable names (#1597)

* [fx] add some comment and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add intepretation for graph analysis.

* [fx] before doing save_tensor_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.

* [fx] fix variable names in ckpt_solvers and unskip tests.

* [fx] commit my changes.

* [fx] restore skips.

* [fx] restore skips.

* [fx] chaange stage into phase.

* [fx] chaange stage into phase.

* [fx] chaange stage into phase.
pull/1604/head
Super Daniel 2022-09-14 14:27:04 +08:00 committed by GitHub
parent faa23b9d9a
commit c8e9b2ad78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 85 additions and 83 deletions

View File

@ -73,10 +73,11 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, 'fwd_out')
n: Node
temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += getattr(n, 'fwd_out')
x += n.meta['fwd_mem_out']
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1

View File

@ -1,11 +1,11 @@
from typing import List, Set, Tuple, Dict
from typing import List, Tuple
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
import math
from .linearize import linearize
from .utils import *
from colossalai.fx.profiler import profile_function, profile_module
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
bw = chain.bweight ## backward time, not used
cw = chain.cweight + [0] ## size of x (and of y)
cbw = chain.cbweight + [0] ## size of xbar
fwd_tmp = chain.fwd_tmp + [0]
bwd_tmp = chain.bwd_tmp + [0]
fwd_mem_tmp = chain.fwd_mem_tmp + [0]
bwd_mem_tmp = chain.bwd_mem_tmp + [0]
# Build table
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
for m in range(mmax + 1):
for i in range(chain.length + 1):
#lmax-lmin = 0
limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
if m >= limit: ## Equation (1)
opt[m][i][i] = fw[i] + bw[i]
else:
@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
for i in range(chain.length + 1 - d):
# for idx in range(i+1, chain.length + 1):
idx = i + d
mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
if idx > i + 1:
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
if m < mmin:
opt[m][i][idx] = float("inf")
else:
@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:
xbar = 0
for n in node:
xbar += n.fwd_tmp + n.fwd_out
xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
return xbar
@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
fwd_time = 0
for n in node:
# minimum flop count is needed
fwd_time += max(n.fwd_flop, 1)
fwd_time += max(n.meta['fwd_flop'], 1)
return fwd_time
@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
bwd_time = 0
for n in node:
# minimum flop count is needed
bwd_time += max(n.bwd_flop, 1)
bwd_time += max(n.meta['bwd_flop'], 1)
return bwd_time
def _get_bwd_tmp(node: List[Node]) -> int:
def _get_bwd_mem_tmp(node: List[Node]) -> int:
"""Get the backward temp memory of a node
Args:
@ -218,29 +218,32 @@ def _get_bwd_tmp(node: List[Node]) -> int:
def _get_deps_size():
deps_size = 0
for key in deps.keys():
deps_size += key.bwd_out
for k, v in deps.items():
if v > 0:
deps_size += k.meta['bwd_mem_out']
return deps_size
bwd_tmp = 0
bwd_mem_tmp = 0
deps = {}
# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for son in node[-1].users:
deps[son] = 1
for child in node[-1].users:
deps[child] = 1
for n in reversed(node):
bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
deps[n] = len(n._input_nodes)
for son in n.users:
deps[son] -= 1
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
deps[n] = len(n.all_input_nodes)
for child in n.users:
if child in deps:
deps[child] -= 1
for key in list(deps.keys()):
if deps[key] == 0:
del deps[key]
return bwd_tmp
return bwd_mem_tmp
def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
bwd_time.append(_bwd_time(node))
x_sizes.append(_compute_output_size(node))
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_bwd.append(_get_bwd_tmp(node))
tmp_bwd.append(_get_bwd_mem_tmp(node))
# if a node with only one inplace op, we need to let x_bar = 0
if len(node) == 1 and _get_inplace(node[0]):
@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
"""
node_list = linearize(gm, cnode)
mem_limit -= parameter_size(gm)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data, mem_unit)

View File

@ -5,24 +5,24 @@ class Chain:
self.bweight = bw
self.cweight = cw
self.cbweight = cbw
self.fwd_tmp = ftmp
self.bwd_tmp = btmp
self.fwd_mem_tmp = ftmp
self.bwd_mem_tmp = btmp
self.length = len(fw)
if check and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
def __repr__(self):
chain_list = []
for i in range(self.length):
chain_list.append(
(self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
self.bwd_mem_tmp[i]))
i = self.length
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
return chain_list.__repr__()

View File

@ -94,12 +94,11 @@ class MetaInfoProp(torch.fx.Interpreter):
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
n.meta['type'] = type(result)
# retain the autograd graph

View File

@ -1,11 +1,12 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size
class Stage(Enum):
class Phase(Enum):
FORWARD = 0
LOSS = 1
BACKWARD = 2
@ -48,24 +49,9 @@ class GraphInfo:
bwd_mem_out: int = 0
def is_forward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.FORWARD
def is_loss(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.LOSS
def is_placeholder(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.PLACEHOLDER
def is_backward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.BACKWARD
def is_phase(n: Node, phase: Phase) -> bool:
assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
return n.meta['phase'] == phase
def is_saved(n: Node):
@ -74,7 +60,7 @@ def is_saved(n: Node):
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
Basically the input graph should have all nodes marked for keyword `phase`.
Nodes should have attribute `out` indicating the output of each node.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
@ -91,18 +77,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
l
=============================================================================
Args:
graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
graph (Graph): The autograd graph with nodes marked for keyword `phase`.
Returns:
graph_info (GraphInfo): Meta information for the dataflow.
"""
def _peak_memory(deps: Dict[Node, int]):
bwd_tmp = 0
peak_mem = 0
for k, v in deps.items():
if v > 0:
bwd_tmp += activation_size(k.meta['out'])
return bwd_tmp
peak_mem += activation_size(k.meta['out'])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
deps = {}
@ -110,19 +96,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
for n in graph.nodes:
n: Node
if is_saved(n) and not any(map(is_loss, n.users)):
if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_in`.
# Any `fwd_in` should be kept in memory even this function
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
# the node, `fwd_tmp` can be freed.
if is_placeholder(n):
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_mem_in += activation_size(n.meta['out'])
if is_forward(n):
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
elif is_backward(n):
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
# liveness analysis is only used in backward
deps[n] = len(n.users)

View File

@ -5,8 +5,8 @@ import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .dataflow import autograd_graph_analysis, Stage
from .memory import WEIRD_OPS
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
from .memory import WEIRD_OPS, activation_size
from .tensor import MetaTensor
from .opcount import flop_mapping
@ -41,14 +41,11 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# `flop_count`` serves as a global dictionary to store results.
flop_count = {
Stage.FORWARD: 0,
Stage.LOSS: 0,
Stage.BACKWARD: 0,
Phase.FORWARD: 0,
Phase.LOSS: 0,
Phase.BACKWARD: 0,
}
# `stage` will mark the stage of autograd from outside scope.
stage = Stage.FORWARD
# FlopTensor not only get the flop statistics of a single node,
# it also build a full autograd graph for this node.
# This makes sure we can analyze the dependencies of memory, and
@ -85,9 +82,9 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['out'] = normalize_tuple(out)
node.meta['stage'] = stage
node.meta['phase'] = phase
def wrap(x):
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
@ -121,7 +118,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
x._node = subgraph.create_node('placeholder',
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['stage'] = Stage.PLACEHOLDER
x._node.meta['phase'] = Phase.PLACEHOLDER
x._node.meta['out'] = (x._tensor,)
tree_map(set_placeholder, args)
@ -135,6 +132,8 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
def unpack(x):
return x
# `phase` will mark the phase of autograd from outside scope.
phase = Phase.FORWARD
# mark saved tensors with saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
if isinstance(target, str):
@ -147,13 +146,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if is_autogradable(out) and out.requires_grad:
stage = Stage.LOSS
phase = Phase.LOSS
loss = out.sum()
stage = Stage.BACKWARD
phase = Phase.BACKWARD
loss.backward()
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
def unwrap(x):
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
@ -180,6 +179,11 @@ def profile_function(target: 'Target') -> Callable:
# If there is an argument that this `call_function` is inplace, we should
# skip the autograd profiling.
if kwargs.get('inplace', False):
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
out, meta = _profile(func, *args, **kwargs)
return out, meta
@ -222,6 +226,11 @@ def profile_module(module: torch.nn.Module) -> Callable:
# If there is an argument that this `call_module` is inplace, we should
# skip the autograd profiling.
if getattr(module, 'inplace', False):
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
return out, meta

View File

@ -38,7 +38,8 @@ def test_linearize():
if isinstance(op, ForwardNograd):
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
assert n.activation_checkpoint[
0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
continue
@ -54,7 +55,8 @@ def test_linearize():
ckpt_idx += 1
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
assert n.activation_checkpoint[
0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
continue
@ -63,7 +65,8 @@ def test_linearize():
in_ckpt = True
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
assert n.activation_checkpoint[
0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
del model
del gm