[fx] Modify solver linearize and add corresponding test (#1531)

* [fx] modify solver linearize and add test

* [fx] add torch11 test of linearize but skip it

* [fx] remove some unused imports
pull/1538/head^2
Boyuan Yao 2022-09-02 10:24:41 +08:00 committed by GitHub
parent 7dc53237c3
commit 56159049e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 181 additions and 99 deletions

View File

@ -167,8 +167,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
use_reentrant = True
non_leaf_input = 0
for var in input_vars[label]:
input_node = [item for item in node_list if item.name == var]
input_node = input_node[0]
input_node = next(item for item in node_list if item.name == var)
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:

View File

@ -114,7 +114,7 @@ def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]
def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:
def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit: int) -> Chain:
fwd_time = []
bwd_time = []
@ -122,22 +122,22 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
x_sizes = [data.numel() * data.element_size()]
# currently we can't get the temp memory needed in fwd and bwd
tmp_fwd = [0] * len(node_dict)
tmp_bwd = [0] * (len(node_dict) + 1)
tmp_fwd = [0] * len(node_list)
tmp_bwd = [0] * (len(node_list) + 1)
for key in node_dict.keys():
for idx, node in enumerate(node_list):
fwd_time.append(0)
bwd_time.append(0)
xbar_sizes.append(0)
x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
for node in node_dict[key]:
fwd_time[-1] += max(node.__flops__, 1)
x_sizes.append(node[-1].meta['tensor_meta'].numel *
torch.tensor([], dtype=node[-1].meta['tensor_meta'].dtype).element_size())
for n in node:
fwd_time[-1] += max(n.__flops__, 1)
# currently we haven't patched the backward flops count
bwd_time[-1] += max(node.__flops__ * 2, 2)
bwd_time[-1] += max(n.__flops__ * 2, 2)
xbar_sizes[-1] += node.__activation__
xbar_sizes[-1] += n.__activation__
xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
@ -151,14 +151,14 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule:
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
op_list = sequence.list_operations()
loss_op = [op for op in op_list if isinstance(op, Loss)][0]
loss_op = next(op for op in op_list if isinstance(op, Loss))
op_list = op_list[:op_list.index(loss_op)]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
for idx, op in enumerate(op_list, 1):
for idx, op in enumerate(op_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
@ -166,16 +166,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx)
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx)
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1
ckpt_region = [idx]
@ -199,13 +199,13 @@ def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_sl
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
"""
node_dict = linearize(gm)
node_list = linearize(gm)
mem_unit = mem_limit // mem_slots
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_dict, data, mem_unit)
chain: Chain = _construct_chain(node_list, data, mem_unit)
opt_table = _compute_table(chain, mem_slots)
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_dict)
_annotate_from_sequence(sequence, node_list)
# set __sequence__ attribute to GraphModule
setattr(gm, "__sequence__", sequence)

View File

@ -1,89 +1,44 @@
from typing import OrderedDict
from torch.fx import GraphModule
from collections import OrderedDict
import pdb
from typing import List
from torch.fx import GraphModule, Node
def linearize(gm: GraphModule) -> dict:
status_dict = {}
node_dict = OrderedDict()
node_idx = 0
for node in gm.graph.nodes:
last_dict_len = len(status_dict)
# remove node from users list in status_dict
for item in status_dict.values():
if node in item:
item.remove(node)
def linearize(gm: GraphModule) -> List[List[Node]]:
"""Linearizing the graph
# pop node from status_dict if it is fully used
for key in list(status_dict):
if len(status_dict[key]) == 0:
status_dict.pop(key)
Args:
gm (GraphModule): GraphModule derived by tracing
# first node in graph, it should be in n0-n1 type,
# where n0 contains only input op, i.e. placeholder
if last_dict_len == 0:
node_dict[node_idx] = [node]
status_dict[node.name] = list(node.users)
node_idx += 1
node_dict[node_idx] = []
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
"""
continue
def _is_sink() -> bool:
"""Check if we can free all dependencies
# boundary case
if len(status_dict) == 0:
# current node region end point = next node region start point
# i.e. n1-n2-n3-... type node, each node contains only one op
if last_dict_len == 1:
if len(node_dict[node_idx]) > 0:
node_idx += 1
node_dict[node_idx] = []
node_dict[node_idx].append(node)
status_dict[node.name] = list(node.users)
Returns:
bool
"""
continue
return not sum([v for _, v in deps.items()])
# n1-n2-n3, if n1 has multiple ops, the last op in n1 will be
# the one who is able to clean all others in status_dict
# and as the last_dict_len > 1, there are multiple ops are used
# by this node, we view it as the end of one node and start a new node
else:
deps = {}
linearized_nodes = []
region = []
node_dict[node_idx].append(node)
status_dict[node.name] = list(node.users)
node_idx += 1
node_dict[node_idx] = []
for n in gm.graph.nodes:
for n_par in n._input_nodes:
deps[n_par] -= 1
region.append(n)
continue
# if the node could free all dependencies in graph
# we could begin a new node
if _is_sink():
linearized_nodes.append(region)
region = []
else:
# currently I will use bigger node structure
# if the following region is activated, the node will be smaller
#################################################
# if last_dict_len == 1:
# if len(node_dict[node_idx]) > 0:
# node_idx += 1
# node_dict[node_idx] = [node]
# status_dict[node.name] = list(node.users)
#
# continue
#################################################
deps[n] = len(n.users)
# in-node case, as the current node can not clean status_dict
# we view it as in-node status, the node will be appended to the
# current node_idx
node_dict[node_idx].append(node)
status_dict[node.name] = list(node.users)
continue
# If the output node use multiple nodes, there might be an
# empty node after the output node
if len(node_dict[node_idx]) == 0:
node_dict.pop[node_idx]
node_idx -= 1
# pop the last two nodes
node_dict.pop(0)
node_dict.pop(node_idx)
return node_dict
# Remove input
linearized_nodes = linearized_nodes[1:-1]
return linearized_nodes

View File

@ -0,0 +1,128 @@
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
with_codegen = True
except:
# fall back to older pytorch version
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
for M, budgets in MODEL_DICT.items():
for budget in budgets:
model = M()
graph = tracer.trace(model)
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
node_list = linearize(gm)
gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
op_list = gm.__sequence__.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
op_list = op_list[:op_list.index(loss_op)]
in_ckpt = False
ckpt_idx = 0
for idx, op in enumerate(op_list):
if in_ckpt:
if isinstance(op, ForwardNograd):
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
continue
if isinstance(op, ForwardEnable):
for n in node_list[idx]:
assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!"
in_ckpt = False
ckpt_idx += 1
continue
if isinstance(op, ForwardCheck):
ckpt_idx += 1
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
continue
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
del model
del gm
del node_list
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
def test_linearize_torch11():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
for M, budgets in MODEL_DICT.items():
for budget in budgets:
model = M()
graph = tracer.trace(model)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
node_list = linearize(gm)
gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
op_list = gm.__sequence__.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
op_list = op_list[:op_list.index(loss_op)]
in_ckpt = False
ckpt_idx = 0
for idx, op in enumerate(op_list):
if in_ckpt:
if isinstance(op, ForwardNograd):
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
continue
if isinstance(op, ForwardEnable):
for n in node_list[idx]:
assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!"
in_ckpt = False
ckpt_idx += 1
continue
if isinstance(op, ForwardCheck):
ckpt_idx += 1
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
continue
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
for n in node_list[idx]:
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
del model
del gm
del node_list
if __name__ == "__main__":
test_linearize()