mirror of https://github.com/hpcaitech/ColossalAI
[fx] Improve linearize and rotor solver (#1586)
* [fx] add nested activation_checkpoint codegen * undo algorithms commits * solver * undo some commits * [fx] torch11 add nested activation checkpoint codegen * remove some imports * [fx] add some comments in activation codegen * [fx] codegen instance error fix * [fx] imporve linearize and rotor solver * [fx] some comments and format modificationpull/1589/head
parent
219f66c571
commit
49ccf8b5f8
|
@ -7,6 +7,7 @@ from .linearize import linearize
|
|||
from .utils import *
|
||||
from colossalai.fx.profiler import profile_function, profile_module
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
|
@ -36,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
|
|||
for m in range(mmax + 1):
|
||||
for i in range(chain.length + 1):
|
||||
#lmax-lmin = 0
|
||||
limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
|
||||
limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
|
||||
if m >= limit: ## Equation (1)
|
||||
opt[m][i][i] = fw[i] + bw[i]
|
||||
else:
|
||||
|
@ -151,6 +152,97 @@ def _get_inplace(node: Node) -> bool:
|
|||
return is_inplace
|
||||
|
||||
|
||||
def _fwd_xbar(node: List[Node]) -> int:
|
||||
"""Get the forward xbar of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: xbar size, unit Byte
|
||||
"""
|
||||
|
||||
xbar = 0
|
||||
for n in node:
|
||||
xbar += n.fwd_tmp + n.fwd_out
|
||||
return xbar
|
||||
|
||||
|
||||
def _fwd_time(node: List[Node]) -> int:
|
||||
"""Get the foward time of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: foward time, extimated by flops count
|
||||
"""
|
||||
|
||||
fwd_time = 0
|
||||
for n in node:
|
||||
# minimum flop count is needed
|
||||
fwd_time += max(n.fwd_flop, 1)
|
||||
return fwd_time
|
||||
|
||||
|
||||
def _bwd_time(node: List[Node]) -> int:
|
||||
"""Get the backward time of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: backward time, extimated by flops count
|
||||
"""
|
||||
|
||||
bwd_time = 0
|
||||
for n in node:
|
||||
# minimum flop count is needed
|
||||
bwd_time += max(n.bwd_flop, 1)
|
||||
return bwd_time
|
||||
|
||||
|
||||
def _get_bwd_tmp(node: List[Node]) -> int:
|
||||
"""Get the backward temp memory of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: backward temp memory, unit Byte
|
||||
"""
|
||||
|
||||
def _get_deps_size():
|
||||
deps_size = 0
|
||||
for key in deps.keys():
|
||||
deps_size += key.bwd_out
|
||||
|
||||
return deps_size
|
||||
|
||||
bwd_tmp = 0
|
||||
deps = {}
|
||||
|
||||
# add all the users for last node into deps,
|
||||
# as those nodes' gradient out will be stored in memory
|
||||
for son in node[-1].users:
|
||||
deps[son] = 1
|
||||
for n in reversed(node):
|
||||
bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
|
||||
deps[n] = len(n._input_nodes)
|
||||
for son in n.users:
|
||||
deps[son] -= 1
|
||||
|
||||
for key in list(deps.keys()):
|
||||
if deps[key] == 0:
|
||||
del deps[key]
|
||||
|
||||
return bwd_tmp
|
||||
|
||||
|
||||
def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
|
||||
|
||||
fwd_time = []
|
||||
|
@ -160,45 +252,32 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
|
|||
xbar_sizes = [_compute_size(data)]
|
||||
x_sizes = [_compute_size(data)]
|
||||
elif isinstance(data, list) or isinstance(data, tuple):
|
||||
xbar_sizes = [_compute_size(obj) for obj in data]
|
||||
x_sizes = [_compute_size(obj) for obj in data]
|
||||
xbar_sizes = [sum([_compute_size(obj) for obj in data])]
|
||||
x_sizes = [sum([_compute_size(obj) for obj in data])]
|
||||
elif isinstance(data, dict):
|
||||
xbar_sizes = [_compute_size(obj) for obj in data.values()]
|
||||
x_sizes = [_compute_size(obj) for obj in data.values()]
|
||||
xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
|
||||
x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
|
||||
|
||||
# currently we can't get the temp memory needed in fwd and bwd
|
||||
# currently we can't get the temp memory needed in fwd
|
||||
tmp_fwd = [0] * len(node_list)
|
||||
tmp_bwd = [0] * (len(node_list) + 1)
|
||||
tmp_bwd = []
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
fwd_time.append(0)
|
||||
bwd_time.append(0)
|
||||
xbar_sizes.append(0)
|
||||
fwd_time.append(_fwd_time(node))
|
||||
bwd_time.append(_bwd_time(node))
|
||||
x_sizes.append(_compute_output_size(node))
|
||||
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
|
||||
tmp_bwd.append(_get_bwd_tmp(node))
|
||||
|
||||
_check_inplace_flag = 1
|
||||
for n in node:
|
||||
fwd_time[-1] += max(n.__flops__, 1)
|
||||
|
||||
# currently we haven't patched the backward flops count
|
||||
bwd_time[-1] += max(n.__flops__ * 2, 2)
|
||||
xbar_sizes[-1] += n.__activation__
|
||||
|
||||
# we need to clear the xbar of previous node as there is
|
||||
# one op in the current node that use the previous node's
|
||||
# output but applies inplace operation on it
|
||||
# NOTE: This process should be done only once as the previous
|
||||
# node will only have one output
|
||||
if _check_inplace_flag:
|
||||
for par in n._input_nodes:
|
||||
if par not in node and _get_inplace(n):
|
||||
xbar_sizes[-2] -= x_sizes[-2]
|
||||
_check_inplace_flag = 0
|
||||
|
||||
xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
|
||||
# if a node with only one inplace op, we need to let x_bar = 0
|
||||
if len(node) == 1 and _get_inplace(node[0]):
|
||||
xbar_sizes[-1] = 0
|
||||
|
||||
bwd_time.append(0)
|
||||
|
||||
# currently we view loss backward temp as zero
|
||||
tmp_bwd.append(0)
|
||||
|
||||
xbar_sizes = _discretize(mem_unit, xbar_sizes)
|
||||
x_sizes = _discretize(mem_unit, x_sizes)
|
||||
tmp_fwd = _discretize(mem_unit, tmp_fwd)
|
||||
|
@ -207,14 +286,17 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
|
|||
return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
|
||||
|
||||
|
||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
|
||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
op_list = op_list[:op_list.index(loss_op)]
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
||||
ckpt_idx = 0
|
||||
in_ckpt = False
|
||||
ckpt_region = []
|
||||
for idx, op in enumerate(op_list, 0):
|
||||
|
||||
# forward annotation
|
||||
for idx, op in enumerate(fwd_list, 0):
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(idx)
|
||||
|
@ -223,7 +305,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
|
|||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", ckpt_idx)
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
@ -231,7 +313,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
|
|||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", ckpt_idx)
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [idx]
|
||||
|
@ -241,12 +323,62 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
|
|||
in_ckpt = True
|
||||
ckpt_region.append(idx)
|
||||
|
||||
# annotate the backward if there is any nested activation checkpoint
|
||||
in_recompute = False
|
||||
for op in bwd_list:
|
||||
if in_recompute:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
else:
|
||||
if not isinstance(op, Backward):
|
||||
in_recompute = True
|
||||
ckpt_idx = 0
|
||||
ckpt_region = []
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# postprocess, make sure every activation checkpoint label in the
|
||||
# same activation checkpoint region (level = 0) has the same length
|
||||
op_list = []
|
||||
for node in node_list:
|
||||
op_list += node
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
|
||||
|
||||
|
||||
def solver_rotor(gm: ColoGraphModule,
|
||||
data,
|
||||
mem_limit: int,
|
||||
mem_slots: int = 500,
|
||||
cnode: List[str] = None) -> ColoGraphModule:
|
||||
cnode: List[str] = None,
|
||||
eps: float = 0.02) -> ColoGraphModule:
|
||||
"""solver that automatically find activation checkpoint in rotor's manner
|
||||
|
||||
Args:
|
||||
|
@ -255,13 +387,14 @@ def solver_rotor(gm: ColoGraphModule,
|
|||
mem_limit (int): memory budget in Byte.
|
||||
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
|
||||
cnode (List[Node], optional): common node list for linearize. Defaults to None.
|
||||
eps (float): epsilon for memory decay. Defaults to 0.02
|
||||
|
||||
Returns:
|
||||
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
|
||||
"""
|
||||
|
||||
node_list = linearize(gm, cnode)
|
||||
mem_unit = mem_limit // mem_slots
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
MetaInfoProp(gm).run(data)
|
||||
chain: Chain = _construct_chain(node_list, data, mem_unit)
|
||||
opt_table = _compute_table(chain, mem_slots)
|
||||
|
|
|
@ -1,6 +1,35 @@
|
|||
from typing import List
|
||||
from typing import List, Any
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
# Common nodes are type of nodes that could be seen as attributes and remain
|
||||
# unchanged throughout the whole model, it will be used several times by
|
||||
# different blocks of model, so that it is hard for us to linearize the graph
|
||||
# when we encounter those kinds of nodes. We let users to annotate some of the
|
||||
# input as common node, such as attention mask, and the followings are some of
|
||||
# the ops that could actually be seen as common nodes. With our common node prop,
|
||||
# we could find some of the "real" common nodes (e.g. the real attention mask
|
||||
# used in BERT and GPT), the rule is simple, for node who's parents are all common
|
||||
# nodes or it's op belongs to the following operations, we view this node as a
|
||||
# newly born common node.
|
||||
# List of target name that could be seen as common node
|
||||
COPS = ["getattr", "getitem", "size"]
|
||||
|
||||
|
||||
def _is_cop(target: Any) -> bool:
|
||||
"""Check if an op could be seen as common node
|
||||
|
||||
Args:
|
||||
target (Any): node target
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
if isinstance(target, str):
|
||||
return target in COPS
|
||||
else:
|
||||
return target.__name__ in COPS
|
||||
|
||||
|
||||
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
|
||||
"""Linearizing the graph
|
||||
|
@ -53,7 +82,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
|
|||
region = []
|
||||
|
||||
# propagate common node attr if possible
|
||||
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]):
|
||||
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
|
||||
cnode.append(n.name)
|
||||
else:
|
||||
deps[n] = len([user for user in n.users if user.op != "output"])
|
||||
|
|
Loading…
Reference in New Issue