mirror of https://github.com/hpcaitech/ColossalAI
[fx] Modify solver linearize and add corresponding test (#1531)
* [fx] modify solver linearize and add test
* [fx] add torch11 test of linearize but skip it
* [fx] remove some unused imports
parent 7dc53237c3
commit 56159049e8
@@ -167,8 +167,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     use_reentrant = True
     non_leaf_input = 0
     for var in input_vars[label]:
-        input_node = [item for item in node_list if item.name == var]
-        input_node = input_node[0]
+        input_node = next(item for item in node_list if item.name == var)
         if input_node.op != "placeholder":
             non_leaf_input = 1
             for user in input_node.users:
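The two removed lines built the full list of matches and then indexed it; the replacement uses next() over a generator expression, which stops scanning at the first match. A minimal standalone sketch of the idiom (the SimpleNode class is a hypothetical stand-in for an fx Node):

# SimpleNode is hypothetical; only .name matters for this sketch.
from dataclasses import dataclass


@dataclass
class SimpleNode:
    name: str


node_list = [SimpleNode("conv1"), SimpleNode("bn1"), SimpleNode("relu")]

# Old idiom: materializes every match, then takes the first.
first = [item for item in node_list if item.name == "bn1"][0]

# New idiom: the generator short-circuits at the first match.
first = next(item for item in node_list if item.name == "bn1")

assert first.name == "bn1"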
@@ -114,7 +114,7 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]


-def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit: int) -> Chain:

     fwd_time = []
     bwd_time = []
@@ -122,22 +122,22 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
     x_sizes = [data.numel() * data.element_size()]

     # currently we can't get the temp memory needed in fwd and bwd
-    tmp_fwd = [0] * len(node_dict)
-    tmp_bwd = [0] * (len(node_dict) + 1)
+    tmp_fwd = [0] * len(node_list)
+    tmp_bwd = [0] * (len(node_list) + 1)

-    for key in node_dict.keys():
+    for idx, node in enumerate(node_list):
         fwd_time.append(0)
         bwd_time.append(0)
         xbar_sizes.append(0)
-        x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
-                       torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
-        for node in node_dict[key]:
-            fwd_time[-1] += max(node.__flops__, 1)
+        x_sizes.append(node[-1].meta['tensor_meta'].numel *
+                       torch.tensor([], dtype=node[-1].meta['tensor_meta'].dtype).element_size())
+        for n in node:
+            fwd_time[-1] += max(n.__flops__, 1)

             # currently we haven't patched the backward flops count
-            bwd_time[-1] += max(node.__flops__ * 2, 2)
+            bwd_time[-1] += max(n.__flops__ * 2, 2)

-            xbar_sizes[-1] += node.__activation__
+            xbar_sizes[-1] += n.__activation__

         xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
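_construct_chain estimates the byte size of each region's output from its tensor_meta: the element count times the per-element size of the dtype, where an empty tensor of that dtype is created purely to query element_size(). A small self-contained sketch of that trick:

import torch


def tensor_bytes(numel: int, dtype: torch.dtype) -> int:
    # torch.tensor([], dtype=dtype) holds no data; it only gives us
    # the per-element byte size of the dtype via element_size().
    return numel * torch.tensor([], dtype=dtype).element_size()


assert tensor_bytes(1024, torch.float32) == 4 * 1024
assert tensor_bytes(1024, torch.float16) == 2 * 1024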
@@ -151,14 +151,14 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
     return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)


-def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule:
+def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
     op_list = sequence.list_operations()
-    loss_op = [op for op in op_list if isinstance(op, Loss)][0]
+    loss_op = next(op for op in op_list if isinstance(op, Loss))
     op_list = op_list[:op_list.index(loss_op)]
     ckpt_idx = 0
     in_ckpt = False
     ckpt_region = []
-    for idx, op in enumerate(op_list, 1):
+    for idx, op in enumerate(op_list, 0):
         if in_ckpt:
             if isinstance(op, ForwardNograd):
                 ckpt_region.append(idx)
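The enumerate start switches from 1 to 0 because the old node_dict (with its placeholder region popped) was keyed from 1, while the new linearize output is a plain 0-indexed list, so the operation index can be used directly as a node_list index. The shift in isolation:

ops = ["op_a", "op_b", "op_c"]

# enumerate(ops, 1) yields (1, 'op_a'), (2, 'op_b'), (3, 'op_c')
# enumerate(ops, 0) yields (0, 'op_a'), (1, 'op_b'), (2, 'op_c')
node_list = [["n0"], ["n1"], ["n2"]]
for idx, op in enumerate(ops, 0):
    region = node_list[idx]    # idx is now a valid list index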
@@ -166,16 +166,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
             elif isinstance(op, ForwardEnable):
                 in_ckpt = False
                 for node_idx in ckpt_region:
-                    for node in node_dict[node_idx]:
-                        setattr(node, "activation_checkpoint", ckpt_idx)
+                    for n in node_list[node_idx]:
+                        setattr(n, "activation_checkpoint", ckpt_idx)

                 ckpt_idx += 1
                 ckpt_region = []

             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
-                    for node in node_dict[node_idx]:
-                        setattr(node, "activation_checkpoint", ckpt_idx)
+                    for n in node_list[node_idx]:
+                        setattr(n, "activation_checkpoint", ckpt_idx)

                 ckpt_idx += 1
                 ckpt_region = [idx]
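Both branches flush the pending ckpt_region the same way: every fx Node in each buffered region is tagged with the current checkpoint index via setattr. A toy sketch of the tagging pattern (ToyNode is hypothetical):

class ToyNode:

    def __init__(self, name: str):
        self.name = name


# two linearized regions scheduled into checkpoint region 0
node_list = [[ToyNode("conv1")], [ToyNode("bn1"), ToyNode("relu")]]
ckpt_region = [0, 1]
ckpt_idx = 0

for node_idx in ckpt_region:
    for n in node_list[node_idx]:
        setattr(n, "activation_checkpoint", ckpt_idx)

assert node_list[1][1].activation_checkpoint == 0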
@@ -199,13 +199,13 @@ def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_sl
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
     """

-    node_dict = linearize(gm)
+    node_list = linearize(gm)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
-    chain: Chain = _construct_chain(node_dict, data, mem_unit)
+    chain: Chain = _construct_chain(node_list, data, mem_unit)
     opt_table = _compute_table(chain, mem_slots)
     sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
-    _annotate_from_sequence(sequence, node_dict)
+    _annotate_from_sequence(sequence, node_list)

     # set __sequence__ attribute to GraphModule
     setattr(gm, "__sequence__", sequence)
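solver_rotor is the entry point that ties the steps together: linearize the graph, propagate meta info, build the rotor Chain, solve the table, and annotate the graph from the resulting sequence. A hedged usage sketch, condensed from the test added below (assumes torch >= 1.12; the budget value is arbitrary):

import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.codegen import ActivationCheckpointCodeGen
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor

tracer = ColoTracer()
model = tm.resnet18()
graph = tracer.trace(model)
graph.set_codegen(ActivationCheckpointCodeGen())    # torch >= 1.12 path
gm = ColoGraphModule(model, graph, model.__class__.__name__)

# meta-device input: shapes propagate without allocating real memory
gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=2100 * 1024**2)

# nodes chosen for checkpointing now carry an activation_checkpoint index
for node in gm.graph.nodes:
    if hasattr(node, "activation_checkpoint"):
        print(node.name, "-> ckpt region", node.activation_checkpoint)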
@@ -1,89 +1,44 @@
-from typing import OrderedDict
-from torch.fx import GraphModule
-from collections import OrderedDict
-import pdb
-
-
-def linearize(gm: GraphModule) -> dict:
-    status_dict = {}
-    node_dict = OrderedDict()
-    node_idx = 0
-    for node in gm.graph.nodes:
-        last_dict_len = len(status_dict)
-        # remove node from users list in status_dict
-        for item in status_dict.values():
-            if node in item:
-                item.remove(node)
-
-        # pop node from status_dict if it is fully used
-        for key in list(status_dict):
-            if len(status_dict[key]) == 0:
-                status_dict.pop(key)
-
-        # first node in graph, it should be in n0-n1 type,
-        # where n0 contains only input op, i.e. placeholder
-        if last_dict_len == 0:
-            node_dict[node_idx] = [node]
-            status_dict[node.name] = list(node.users)
-            node_idx += 1
-            node_dict[node_idx] = []
-
-            continue
-
-        # boundary case
-        if len(status_dict) == 0:
-            # current node region end point = next node region start point
-            # i.e. n1-n2-n3-... type node, each node contains only one op
-            if last_dict_len == 1:
-                if len(node_dict[node_idx]) > 0:
-                    node_idx += 1
-                    node_dict[node_idx] = []
-                node_dict[node_idx].append(node)
-                status_dict[node.name] = list(node.users)
-
-                continue
-
-            # n1-n2-n3, if n1 has multiple ops, the last op in n1 will be
-            # the one who is able to clean all others in status_dict
-            # and as the last_dict_len > 1, there are multiple ops are used
-            # by this node, we view it as the end of one node and start a new node
-            else:
-                node_dict[node_idx].append(node)
-                status_dict[node.name] = list(node.users)
-                node_idx += 1
-                node_dict[node_idx] = []
-
-                continue
-
-        else:
-            # currently I will use bigger node structure
-            # if the following region is activated, the node will be smaller
-            #################################################
-            # if last_dict_len == 1:
-            #     if len(node_dict[node_idx]) > 0:
-            #         node_idx += 1
-            #     node_dict[node_idx] = [node]
-            #     status_dict[node.name] = list(node.users)
-            #
-            #     continue
-            #################################################
-
-            # in-node case, as the current node can not clean status_dict
-            # we view it as in-node status, the node will be appended to the
-            # current node_idx
-            node_dict[node_idx].append(node)
-            status_dict[node.name] = list(node.users)
-
-            continue
-
-    # If the output node use multiple nodes, there might be an
-    # empty node after the output node
-    if len(node_dict[node_idx]) == 0:
-        node_dict.pop[node_idx]
-        node_idx -= 1
-
-    # pop the last two nodes
-    node_dict.pop(0)
-    node_dict.pop(node_idx)
-    return node_dict
+from typing import List
+from torch.fx import GraphModule, Node
+
+
+def linearize(gm: GraphModule) -> List[List[Node]]:
+    """Linearizing the graph
+
+    Args:
+        gm (GraphModule): GraphModule derived by tracing
+
+    Returns:
+        List[List[Node]]: List of list, each inside list of Node presents
+        the actual 'node' in linearized manner.
+    """
+
+    def _is_sink() -> bool:
+        """Check if we can free all dependencies
+
+        Returns:
+            bool
+        """
+
+        return not sum([v for _, v in deps.items()])
+
+    deps = {}
+    linearized_nodes = []
+    region = []
+
+    for n in gm.graph.nodes:
+        for n_par in n._input_nodes:
+            deps[n_par] -= 1
+        region.append(n)
+
+        # if the node could free all dependencies in graph
+        # we could begin a new node
+        if _is_sink():
+            linearized_nodes.append(region)
+            region = []
+
+        deps[n] = len(n.users)
+
+    # Remove input
+    linearized_nodes = linearized_nodes[1:-1]
+
+    return linearized_nodes
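The rewritten linearize is a single pass with reference counting: deps tracks how many unconsumed users each node still has, and a region closes whenever the current node drains every outstanding dependency (_is_sink). A self-contained toy re-implementation on a diamond-shaped graph (ToyNode and its wiring are hypothetical stand-ins for fx Nodes):

class ToyNode:

    def __init__(self, name: str):
        self.name = name
        self.users = []
        self._input_nodes = []

    def feeds(self, other: "ToyNode"):
        self.users.append(other)
        other._input_nodes.append(self)


def linearize_toy(nodes):
    deps, regions, region = {}, [], []
    for n in nodes:
        for n_par in n._input_nodes:
            deps[n_par] -= 1    # n consumes one pending use of each input
        region.append(n)
        # sink check: no node has unconsumed users left, so the region closes
        if not sum(deps.values()):
            regions.append(region)
            region = []
        deps[n] = len(n.users)
    return regions


a, b, c, d = (ToyNode(x) for x in "abcd")
a.feeds(b); a.feeds(c); b.feeds(d); c.feeds(d)
# a stays live while the b/c branches run, so they fuse into one region:
assert [[n.name for n in r] for r in linearize_toy([a, b, c, d])] == [["a"], ["b", "c", "d"]]
# the real linearize additionally strips the first and last regions
# (placeholder inputs and the output), via linearized_nodes[1:-1]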
@@ -0,0 +1,128 @@
+import torch
+import torchvision.models as tm
+from colossalai.fx import ColoTracer
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms import solver_rotor, linearize
+from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd
+import pytest
+
+try:
+    from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
+except:
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False
+
+
+@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
+def test_linearize():
+    MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
+    tracer = ColoTracer()
+    for M, budgets in MODEL_DICT.items():
+        for budget in budgets:
+            model = M()
+            graph = tracer.trace(model)
+            graph.set_codegen(ActivationCheckpointCodeGen())
+            gm = ColoGraphModule(model, graph, model.__class__.__name__)
+            node_list = linearize(gm)
+            gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
+            op_list = gm.__sequence__.list_operations()
+            loss_op = next(op for op in op_list if isinstance(op, Loss))
+            op_list = op_list[:op_list.index(loss_op)]
+            in_ckpt = False
+            ckpt_idx = 0
+            for idx, op in enumerate(op_list):
+                if in_ckpt:
+                    if isinstance(op, ForwardNograd):
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                    if isinstance(op, ForwardEnable):
+                        for n in node_list[idx]:
+                            assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!"
+                        in_ckpt = False
+
+                        ckpt_idx += 1
+                        continue
+
+                    if isinstance(op, ForwardCheck):
+                        ckpt_idx += 1
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                else:
+                    if isinstance(op, ForwardCheck):
+                        in_ckpt = True
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+            del model
+            del gm
+            del node_list
+
+
+@pytest.mark.skip(reason="torch11 meta tensor not implemented")
+@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
+def test_linearize_torch11():
+    MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
+    tracer = ColoTracer()
+    for M, budgets in MODEL_DICT.items():
+        for budget in budgets:
+            model = M()
+            graph = tracer.trace(model)
+            gm = ColoGraphModule(model, graph, model.__class__.__name__)
+            gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+            node_list = linearize(gm)
+            gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
+            op_list = gm.__sequence__.list_operations()
+            loss_op = next(op for op in op_list if isinstance(op, Loss))
+            op_list = op_list[:op_list.index(loss_op)]
+            in_ckpt = False
+            ckpt_idx = 0
+            for idx, op in enumerate(op_list):
+                if in_ckpt:
+                    if isinstance(op, ForwardNograd):
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                    if isinstance(op, ForwardEnable):
+                        for n in node_list[idx]:
+                            assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!"
+                        in_ckpt = False
+
+                        ckpt_idx += 1
+                        continue
+
+                    if isinstance(op, ForwardCheck):
+                        ckpt_idx += 1
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                else:
+                    if isinstance(op, ForwardCheck):
+                        in_ckpt = True
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+            del model
+            del gm
+            del node_list
+
+
+if __name__ == "__main__":
+    test_linearize()
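The two tests are gated by the same probe: the ActivationCheckpointCodeGen import succeeds only on newer torch (>= 1.12.0), so with_codegen selects exactly one of the pair at collection time. The feature-detection idiom in isolation (here with ImportError named explicitly, where the test uses a bare except):

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen    # torch >= 1.12.0
    with_codegen = True
except ImportError:
    # older torch: only the legacy python_code_with_activation_checkpoint path exists
    with_codegen = False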