[fx] Improve linearize and rotor solver (#1586)

* [fx] add nested activation_checkpoint codegen

* undo algorithms commits

* solver

* undo some commits

* [fx] torch11 add nested activation checkpoint codegen

* remove some imports

* [fx] add some comments in activation codegen

* [fx] codegen instance error fix

* [fx] imporve linearize and rotor solver

* [fx] some comments and format modification
pull/1589/head
Boyuan Yao 2022-09-13 14:50:04 +08:00 committed by GitHub
parent 219f66c571
commit 49ccf8b5f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 201 additions and 39 deletions

View File

@ -7,6 +7,7 @@ from .linearize import linearize
from .utils import *
from colossalai.fx.profiler import profile_function, profile_module
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
# this is the python compute table code from rotor
@ -36,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
for m in range(mmax + 1):
for i in range(chain.length + 1):
#lmax-lmin = 0
limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
if m >= limit: ## Equation (1)
opt[m][i][i] = fw[i] + bw[i]
else:
@ -151,6 +152,97 @@ def _get_inplace(node: Node) -> bool:
return is_inplace
def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: xbar size, unit Byte
"""
xbar = 0
for n in node:
xbar += n.fwd_tmp + n.fwd_out
return xbar
def _fwd_time(node: List[Node]) -> int:
"""Get the foward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: foward time, extimated by flops count
"""
fwd_time = 0
for n in node:
# minimum flop count is needed
fwd_time += max(n.fwd_flop, 1)
return fwd_time
def _bwd_time(node: List[Node]) -> int:
"""Get the backward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward time, extimated by flops count
"""
bwd_time = 0
for n in node:
# minimum flop count is needed
bwd_time += max(n.bwd_flop, 1)
return bwd_time
def _get_bwd_tmp(node: List[Node]) -> int:
"""Get the backward temp memory of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward temp memory, unit Byte
"""
def _get_deps_size():
deps_size = 0
for key in deps.keys():
deps_size += key.bwd_out
return deps_size
bwd_tmp = 0
deps = {}
# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for son in node[-1].users:
deps[son] = 1
for n in reversed(node):
bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
deps[n] = len(n._input_nodes)
for son in n.users:
deps[son] -= 1
for key in list(deps.keys()):
if deps[key] == 0:
del deps[key]
return bwd_tmp
def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
fwd_time = []
@ -160,45 +252,32 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
xbar_sizes = [_compute_size(data)]
x_sizes = [_compute_size(data)]
elif isinstance(data, list) or isinstance(data, tuple):
xbar_sizes = [_compute_size(obj) for obj in data]
x_sizes = [_compute_size(obj) for obj in data]
xbar_sizes = [sum([_compute_size(obj) for obj in data])]
x_sizes = [sum([_compute_size(obj) for obj in data])]
elif isinstance(data, dict):
xbar_sizes = [_compute_size(obj) for obj in data.values()]
x_sizes = [_compute_size(obj) for obj in data.values()]
xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
# currently we can't get the temp memory needed in fwd and bwd
# currently we can't get the temp memory needed in fwd
tmp_fwd = [0] * len(node_list)
tmp_bwd = [0] * (len(node_list) + 1)
tmp_bwd = []
for idx, node in enumerate(node_list):
fwd_time.append(0)
bwd_time.append(0)
xbar_sizes.append(0)
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(_compute_output_size(node))
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_bwd.append(_get_bwd_tmp(node))
_check_inplace_flag = 1
for n in node:
fwd_time[-1] += max(n.__flops__, 1)
# currently we haven't patched the backward flops count
bwd_time[-1] += max(n.__flops__ * 2, 2)
xbar_sizes[-1] += n.__activation__
# we need to clear the xbar of previous node as there is
# one op in the current node that use the previous node's
# output but applies inplace operation on it
# NOTE: This process should be done only once as the previous
# node will only have one output
if _check_inplace_flag:
for par in n._input_nodes:
if par not in node and _get_inplace(n):
xbar_sizes[-2] -= x_sizes[-2]
_check_inplace_flag = 0
xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
# if a node with only one inplace op, we need to let x_bar = 0
if len(node) == 1 and _get_inplace(node[0]):
xbar_sizes[-1] = 0
bwd_time.append(0)
# currently we view loss backward temp as zero
tmp_bwd.append(0)
xbar_sizes = _discretize(mem_unit, xbar_sizes)
x_sizes = _discretize(mem_unit, x_sizes)
tmp_fwd = _discretize(mem_unit, tmp_fwd)
@ -207,14 +286,17 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
op_list = op_list[:op_list.index(loss_op)]
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
for idx, op in enumerate(op_list, 0):
# forward annotation
for idx, op in enumerate(fwd_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
@ -223,7 +305,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", ckpt_idx)
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = []
@ -231,7 +313,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", ckpt_idx)
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = [idx]
@ -241,12 +323,62 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
in_ckpt = True
ckpt_region.append(idx)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
def solver_rotor(gm: ColoGraphModule,
data,
mem_limit: int,
mem_slots: int = 500,
cnode: List[str] = None) -> ColoGraphModule:
cnode: List[str] = None,
eps: float = 0.02) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
@ -255,13 +387,14 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.02
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
"""
node_list = linearize(gm, cnode)
mem_unit = mem_limit // mem_slots
mem_unit = mem_limit * (1.0 - eps) // mem_slots
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data, mem_unit)
opt_table = _compute_table(chain, mem_slots)

View File

@ -1,6 +1,35 @@
from typing import List
from typing import List, Any
from torch.fx import GraphModule, Node
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
# different blocks of model, so that it is hard for us to linearize the graph
# when we encounter those kinds of nodes. We let users to annotate some of the
# input as common node, such as attention mask, and the followings are some of
# the ops that could actually be seen as common nodes. With our common node prop,
# we could find some of the "real" common nodes (e.g. the real attention mask
# used in BERT and GPT), the rule is simple, for node who's parents are all common
# nodes or it's op belongs to the following operations, we view this node as a
# newly born common node.
# List of target name that could be seen as common node
COPS = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in COPS
else:
return target.__name__ in COPS
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
"""Linearizing the graph
@ -53,7 +82,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
region = []
# propagate common node attr if possible
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]):
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])