[fx] remove deprecated algorithms. (#2312) (#2313)

pull/2987/head
Super Daniel 2023-03-07 10:30:35 +08:00 committed by GitHub
parent 55dcd3051a
commit b42d3d28ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 0 additions and 1970 deletions

View File

@ -1,4 +0,0 @@
from .ckpt_solver_chen import chen_greedy
from .linearize import linearize
from .ckpt_solver_rotor import solver_rotor
from .ckpt_solver_pofo import solver_pofo

View File

@ -1,15 +0,0 @@
from setuptools import setup, Extension
import os
# Directory that contains this build script; the C source path is built relative to it.
this_dir = os.path.dirname(os.path.abspath(__file__))
c_source = os.path.join(this_dir, 'dynamic_programs.c')

# One extension module, compiled from the single C file above.
ext_modules = [
    Extension('dynamic_programs_C_version', sources=[c_source]),
]

setup(
    name='rotor c extension',
    version='0.1',
    description='rotor c extension for faster dp computing',
    ext_modules=ext_modules,
)

View File

@ -1,98 +0,0 @@
import math
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
# Public API of this module.
__all__ = ['chen_greedy']
# Node op kinds that are allowed to carry an activation-checkpoint annotation.
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
    """
    Collect every node of ``gm`` at which a checkpoint segment may end.

    The graph is walked in its stored order while counting, for each node seen so far,
    how many of its users have not run yet. A node qualifies when its op is listed in
    ``CKPT_OP`` and, right before it is registered, no earlier node has pending users
    (i.e. everything produced so far could be freed -- the node is a "sink").

    Args:
        gm (GraphModule): the traced module to inspect

    Returns:
        List: the candidate nodes, in graph order
    """
    pending_users = {}
    candidates = []
    for node in gm.graph.nodes:
        # executing ``node`` consumes one pending use of each of its inputs
        for producer in node._input_nodes:
            pending_users[producer] -= 1

        # only ops in CKPT_OP may be annotated, and only when nothing is still pending
        if node.op in CKPT_OP and not sum(pending_users.values()):
            candidates.append(node)

        # from now on ``node`` is kept alive by each of its users
        pending_users[node] = len(node.users)
    return candidates
def chen_greedy(gm: GraphModule) -> GraphModule:
    """
    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
    Note that this algorithm targets at memory optimization only, using techniques in appendix A.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm)

    Args:
        gm (GraphModule): The module to add checkpoints

    Returns:
        GraphModule: the same ``gm``; nodes inside the chosen segments whose op is in
        ``CKPT_OP`` get an integer ``activation_checkpoint`` attribute (the segment index),
        and the module is recompiled.
    """

    def run_chen_greedy(b: int = 0) -> Tuple[List, int]:
        """
        One pass of the greedy segmentation with budget ``b``.

        Returns the list of ``(start, end)`` node-index intervals and the budget
        estimate ``floor(sqrt(x * y))`` derived from this pass.
        """
        ckpt_intv = []
        temp = 0    # memory accumulated since the last checkpoint
        x = 0    # total size kept at checkpoints
        y = 0    # largest accumulated memory seen
        # NOTE(review): the first segment starts at node index 2 -- presumably to skip the
        # leading graph nodes; confirm against the intended graph layout.
        prev_idx = 2
        for (idx, n) in enumerate(gm.graph.nodes):
            n: Node
            temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
            y = max(y, temp)
            if temp > b and n in ckpt_nodes:
                x += calculate_fwd_in(n)
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    def grid_search(num_grids: int = 6) -> List:
        """
        Search ckpt strategy with b = 0, then run the allocation algorithm again with b = sqrt(xy).
        Grid search over [sqrt(2)/2 b, sqrt(2) b] for ckpt_opt over num_grids as in appendix A.
        """
        baseline_intv, b_approx = run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        # A search window narrower than num_grids would give a step of 0, which range() rejects.
        step = max(1, (b_max - b_min) // num_grids)
        b_opt = math.inf
        # Fall back to the b = 0 strategy when the window is empty or no pass improves on inf.
        ckpt_opt = baseline_intv
        for b in range(b_min, b_max, step):
            ckpt_intv, b_approx = run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt

    gm.graph.lint()    # make sure nodes are in topological order
    # The candidates do not depend on the budget: compute them once and use a set for O(1) lookups.
    ckpt_nodes = set(_all_potential_ckpt_nodes(gm))
    ckpt = grid_search(num_grids=6)
    node_list = list(gm.graph.nodes)
    for i, seg in enumerate(ckpt):
        for idx in range(*seg):
            n = node_list[idx]
            if n.op in CKPT_OP:
                setattr(n, 'activation_checkpoint', i)
    gm.recompile()
    return gm

View File

@ -1,537 +0,0 @@
import copy
import math
from typing import List, Tuple
import torch
from colossalai.fx import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import \
_find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import parameter_size
from torch.fx import GraphModule, Node
from .linearize import linearize
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
Sequence)
# Cost stored in the DP tables for states that have no feasible choice.
INF = float("inf")
def _normalize_flops(chain: Chain, flops) -> Chain:
"""
Normalize flops
"""
for i in range(chain.length):
chain.fweight[i] /= flops
chain.bweight[i] /= flops
return chain
class PofoTable:
    """PofoTable
    The PofoTable contains the necessary components to store intermediate results
    of dynamic programming and the operations along the way.
    """

    def __init__(self, chain_length: int, mem_slots: int):
        """Init pofo table
        The pofo table contains two tables, opt and what, indicating values and
        operations.

        Args:
            chain_length (int): chain length
            mem_slots (int): number of memory slots
        """

        self.length = chain_length
        self.mem_slots = mem_slots

        # initializing tables
        # the first bool indicates whether the input has bar
        # opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
        # what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
        # where is_enable indicates whether we enable the gradient, is_offload indicates whether we
        # offload the input, index indicates the end of F_\empty sequence if is_enable = False
        self.opt = {
            False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
            True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
        }
        self.what = {
            False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
            True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
        }

    def _get_value(self, state, table, default):
        """Look up ``state`` in ``table``.

        Args:
            state: tuple ``(i, act_size, df, db, input_has_bar)`` in discretized units
            table: either ``self.opt`` or ``self.what``
            default: value returned when the state exceeds the slot budget or was never stored

        Returns:
            The stored entry, or ``default``.
        """
        i, act_size, df, db, input_has_bar = state
        # states that do not fit in the slot budget are never stored
        if act_size + df > self.mem_slots or act_size + db > self.mem_slots:
            return default

        try:
            return table[input_has_bar][i][act_size][(df, db)]
        except KeyError:
            print(f"state not found {state}")
            # A missing state is treated like an infeasible one; falling through used to
            # return None, which broke arithmetic on the result of get_opt().
            return default

    def get_opt(self, state):
        """Return the stored cost of ``state``; infinity when it is infeasible or unknown."""
        return self._get_value(state, self.opt, math.inf)

    def get_what(self, state):
        """Return the stored decision of ``state``; ``None`` when it is infeasible or unknown.

        ``None`` is the same "no decision" marker that ``set_value`` callers store for
        infeasible states, so callers can test ``what is None`` uniformly.
        """
        return self._get_value(state, self.what, None)

    def set_value(self, state, opt, what):
        """Store cost ``opt`` and decision ``what`` for the discretized ``state``."""
        i, act_size, df, db, input_has_bar = state
        self.opt[input_has_bar][i][act_size][(df, db)] = opt
        self.what[input_has_bar][i][act_size][(df, db)] = what
class PofoSolver:
    """PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html

    The pofo solver is based on the paper "Efficient Combination of Rematerialization and Offloading
    for Training DNNs" and the code given in its supplemental material. Currently we do not use the whole
    setup of the original paper and reuse the rotor solver for the backward sequence, as suggested in the
    supplemental. The solver is able to find strategies that include offloading.
    """

    def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None:
        """Store the problem description and immediately build both DP tables.

        Args:
            chain (Chain): chain to solve
            max_memory (int): memory budget
            bandwidth: divisor/multiplier that converts between data amounts and time
            mem_slots (int): number of slots the memory budget is discretized into
        """
        self.chain = chain
        self.length = chain.length
        self.max_memory = max_memory
        self.mem_slots = mem_slots
        # size of one memory slot
        self.mem_unit = max_memory / mem_slots
        self.bandwidth = bandwidth

        # the rotor table is computed on a discretized copy so that self.chain keeps its original sizes
        self.disc_chain = copy.deepcopy(self.chain)
        self.disc_chain._discretize(self.mem_unit)

        self.rotor_table = _compute_table(self.disc_chain, mem_slots)
        self._compute_pofo_table()

    def _discretize(self, *values) -> Tuple:
        """Convert sizes into slot counts, rounding each one up."""
        return tuple(math.ceil(value / self.mem_unit) for value in values)

    def _undiscretize(self, *discrete_values) -> Tuple:
        """Convert slot counts back into sizes.

        Returns a single number when called with one value and a tuple otherwise.
        """
        if len(discrete_values) == 1:
            return discrete_values[0] * self.mem_unit
        else:
            return tuple(d * self.mem_unit for d in discrete_values)

    def _mmax_all(self, idx: int):
        """
        Calculate the maximum memory usage of Fi_all
        """

        return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx]

    def _mmax_b(self, idx: int):
        """
        Calculate the maximum memory usage of Bi
        """

        return self.chain.cbweight[idx + 1] + self.chain.cweight[idx + 1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx]

    def _mmax_ng(self, i: int, j: int):
        """
        Calculate the maximum memory usage of CF_i followed by the no-grad forwards F_{i+1} ... F_j
        """

        res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j]
        # for j == i the input of stage j is not added
        if j > i:
            res += self.chain.cweight[j]
        return res

    def _rotor_estimated_bwd(self, i, j, m, delta):
        """Estimate the backward time of stages ``i..j`` with memory ``m`` while ``delta`` data is transferred.

        The compute time is read from the rotor table; the result is the mean of
        ``max(compute, comm)`` and ``compute + comm``.
        """
        compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j]
        comm = delta / self.bandwidth
        return (max(compute, comm) + compute + comm) / 2

    def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
        """Return the rotor sequence for stages ``i..j`` with memory ``m``; ``delta`` is not used here."""
        return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)

    def _common_values_enable(self, state: Tuple):
        """Evaluate running stage ``idx`` with gradients enabled from the (undiscretized) ``state``.

        Returns:
            ``None`` when the stage does not fit in the remaining memory, otherwise the tuple
            ``(offload_data, prefetch_data, total_time, idle_time)``.
        """

        idx, act_size, df, db, input_has_bar = state
        input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
        mf = act_size + df + input_size
        mb = act_size + db + input_size
        mem_avail = self.max_memory - act_size - input_size
        f_usage = self._mmax_all(idx)
        b_usage = self._mmax_b(idx)

        # infeasible
        if f_usage > mem_avail or b_usage > mem_avail:
            return None

        # calculate idle time
        eps_f_beta = max(0, f_usage - self.max_memory + mf)
        eps_b_beta = max(0, b_usage - self.max_memory + mb)
        idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth

        # calculate offload and prefetch data
        offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta
        prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta

        # total_time
        total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time

        return (offload_data, prefetch_data, total_time, idle_time)

    def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False):
        """Evaluate checkpointing stage ``i`` and running stages up to ``j`` without gradients.

        Args:
            state (Tuple): undiscretized state ``(i, act_size, df, db, input_has_bar)``
            j (int): last stage of the no-grad run
            iterative (bool): when True, ``self.epsilon_tmp`` and ``self.sum_fwds`` are updated
                incrementally from their current values (the caller resets them and then calls
                with increasing ``j``); when False they are recomputed from scratch

        Returns:
            ``None`` when infeasible, otherwise
            ``(offload_data, prefetch_data, total_time, idle_time)``.
        """

        i, act_size, df, db, input_has_bar = state

        # compute new epsilon_tmp and sum_fwds
        if iterative:
            self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds)
            self.sum_fwds += self.chain.fweight[j]
        else:
            self.epsilon_tmp = max(
                self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1))
            self.sum_fwds = sum(self.chain.fweight[i:j + 1])

        input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]
        mf = act_size + df + input_size
        mem_avail = self.max_memory - act_size - input_size

        # if infeasible
        # NOTE(review): the check ranges over every k up to self.length, not just up to j -- confirm intent.
        if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail:
            return None

        eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf)
        offload_data = self.sum_fwds * self.bandwidth + eps_f_beta

        # TODO: Implement the precise backward recompute sequence mentioned in the paper
        # currently we will use an approximate way to get the backward time
        time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db)

        prefetch_data = time_backward * self.bandwidth
        idle_time = eps_f_beta / self.bandwidth
        total_time = self.sum_fwds + idle_time + time_backward

        return (offload_data, prefetch_data, total_time, idle_time)

    def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple:
        """Generate new values for next state

        Args:
            state (Tuple): undiscretized states
            do_offload (bool): bool type indicates whether we need to do offload
            common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)

        Returns:
            Tuple: (new_act_size, new_df, new_db)
        """
        idx, act_size, df, db, input_has_bar = state
        offload_data, prefetch_data, *_ = common_values
        input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
        if do_offload:
            # the input is offloaded: it joins the pending transfers instead of the resident activations
            new_act_size = act_size
            new_df = max(0, df + input_size - offload_data)
            new_db = max(0, db - prefetch_data) + input_size
        else:
            # the input stays resident
            new_act_size = act_size + input_size
            new_df = max(0, df - offload_data)
            new_db = max(0, db - prefetch_data)

        return (new_act_size, new_df, new_db)

    def _compute_pofo_table(self):
        """Fill ``self.table`` for every stage, from the last index (the loss) down to stage 0."""
        self.table = PofoTable(self.length, self.mem_slots)

        # initializing the loss
        for act_size in range(self.mem_slots + 1):
            for df in range(self.mem_slots - act_size + 1):
                for db in range(self.mem_slots - act_size + 1):
                    # undiscretize for idle time calculation
                    origin_values = self._undiscretize(act_size, df, db)

                    for input_has_bar in (False, True):
                        disc_state = (self.length, act_size, df, db, input_has_bar)
                        state = (self.length, *origin_values, input_has_bar)
                        common_values = self._common_values_enable(state)

                        # if no feasible choice
                        if common_values is None:
                            self.table.set_value(disc_state, INF, None)
                            continue

                        # if there is feasible choice
                        new_act_size, new_df, new_db = self._new_values(state, False, common_values)
                        # time needed to finish the transfers that are still pending
                        eps_g = (new_df + new_db) / self.bandwidth
                        total_time = common_values[2] + eps_g
                        self.table.set_value(disc_state, total_time, (True, False))

        # main loop
        for i in reversed(range(self.length)):
            for act_size in range(self.mem_slots + 1):
                for df in range(self.mem_slots - act_size + 1):
                    for db in range(self.mem_slots - act_size + 1):
                        # undiscretize for idle time calculation
                        origin_values = self._undiscretize(act_size, df, db)

                        for input_has_bar in (False, True):
                            best_result = INF
                            best_choice = None
                            disc_state = (i, act_size, df, db, input_has_bar)
                            state = (i, *origin_values, input_has_bar)

                            # case 1: start with F_all
                            vals_enable = self._common_values_enable(state)
                            if vals_enable is not None:
                                for do_offload in (True, False):
                                    new_state = self._new_values(state, do_offload, vals_enable)
                                    new_state = (i + 1, *self._discretize(*new_state), True)
                                    total_time = vals_enable[2]
                                    results_all = self.table.get_opt(new_state) + total_time
                                    if results_all < best_result:
                                        best_result = results_all
                                        best_choice = (True, do_offload)

                            # case 2: start with F_ck
                            # reset the accumulators that _common_values_nograd updates iteratively
                            self.sum_fwds = 0
                            self.epsilon_tmp = 0
                            for j in range(i, self.length):
                                vals_nograd = self._common_values_nograd(state, j, True)

                                # if infeasible
                                if vals_nograd is None:
                                    continue

                                for do_offload in (True, False):
                                    new_state = self._new_values(state, do_offload, vals_nograd)
                                    new_state = (j + 1, *self._discretize(*new_state), False)
                                    total_time = vals_nograd[2]
                                    result_nograd = total_time + self.table.get_opt(new_state)
                                    if result_nograd < best_result:
                                        best_result = result_nograd
                                        best_choice = (False, do_offload, j)

                            self.table.set_value(disc_state, best_result, best_choice)

    def pofo_rec(self, disc_state):
        """Rebuild the operation sequence that the table stores for ``disc_state``.

        Args:
            disc_state: discretized state ``(i, act_size, df, db, input_has_bar)``

        Returns:
            A ``Sequence`` of operations, or ``None`` when the table holds no decision for the state.
        """
        i, act_size, df, db, input_has_bar = disc_state
        result = Sequence(Function("pofo", *disc_state))
        what = self.table.get_what(disc_state)
        # switch to undiscretized sizes for the value computations below
        state = self._undiscretize(act_size, df, db)
        state = (i, *state, input_has_bar)
        i, act_size, df, db, input_has_bar = state

        if what is None:
            return None

        # if loss
        if i == self.length:
            result.insert(Loss())
            return result

        if what[0]:
            # decision: run stage i with gradients enabled
            do_offload = what[1]
            values = self._common_values_enable(state)
            new_state = self._discretize(*self._new_values(state, do_offload, values))
            new_state = (i + 1, *new_state, True)
            if do_offload:
                result.insert(Offload(i, input_has_bar))
            result.insert(ForwardEnable(i))
            result.insert_sequence(self.pofo_rec(new_state))
            if do_offload:
                result.insert(Prefetch(i, input_has_bar))
            result.insert(Backward(i))

        else:
            # decision: checkpoint stage i and run stages i+1..j without gradients
            _, do_offload, j = what
            values = self._common_values_nograd(state, j)
            new_state = self._discretize(*self._new_values(state, do_offload, values))
            new_state = (j + 1, *new_state, False)
            if do_offload:
                result.insert(Offload(i, input_has_bar))
            result.insert(ForwardCheck(i))
            for k in range(i + 1, j + 1):
                result.insert(ForwardNograd(k))
            result.insert_sequence(self.pofo_rec(new_state))
            if do_offload:
                result.insert(Prefetch(i, input_has_bar))
            m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i])

            #TODO: Implement the precise backward recompute sequence mentioned in the paper
            result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db))

        return result
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
    """Write the decisions of a pofo ``sequence`` onto the nodes of ``node_list``.

    Nodes receive an ``activation_checkpoint`` list (one entry per nesting level) and,
    for offloaded stages, an ``activation_offload`` attribute.

    Args:
        sequence (Sequence): operation sequence containing exactly the ops listed by ``list_operations()``
        node_list (List[List[Node]]): linearized graph; ``node_list[k]`` holds the nodes of stage ``k``
    """
    op_list = sequence.list_operations()
    # everything before the Loss op is the forward pass, everything after it the backward pass
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    fwd_list = op_list[:op_list.index(loss_op)]
    bwd_list = op_list[op_list.index(loss_op) + 1:]
    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []

    # forward annotation
    for op in fwd_list:
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                # a grad-enabled forward closes the current region
                in_ckpt = False
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                # a new checkpoint closes the current region and opens the next one
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])

                ckpt_idx += 1
                ckpt_region = [op.index]

        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(op.index)

    # annotate the backward if there is any nested activation checkpoint
    in_recompute = False
    for op in bwd_list:
        if in_recompute:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [op.index]

            elif isinstance(op, Backward):
                # a Backward op ends the recomputation phase
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                in_recompute = False

        else:
            if not isinstance(op, Backward):
                # any forward op inside the backward pass starts a recomputation; numbering restarts at 0
                in_recompute = True
                ckpt_idx = 0
                ckpt_region = []
                if isinstance(op, ForwardCheck):
                    ckpt_region.append(op.index)

    # postprocess, make sure every activation checkpoint label in the
    # same activation checkpoint region (level = 0) has the same length
    op_list = []
    for node in node_list:
        op_list += node
    ckpt_regions = _find_nested_ckpt_regions(op_list)
    for (start_idx, end_idx) in ckpt_regions:
        nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
        for idx in range(start_idx, end_idx + 1):
            # pad shorter labels with None
            op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))

    # annotate the offload
    offload_idx = 0
    for idx, op in enumerate(fwd_list):
        if isinstance(op, Offload):
            # corner case: offload input
            if op.index == 0:
                if isinstance(fwd_list[idx + 1], ForwardCheck):
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", True)
                else:
                    # NOTE(review): a tuple is stored here while the other branches store a list; the
                    # `activation_offload[-1] = True` write below would fail on a tuple -- confirm.
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", (offload_idx, True, False))
                    offload_idx += 1

            else:
                if op.has_bar:
                    # annotate previous node
                    if hasattr(node_list[op.index - 1][0], "activation_offload"):
                        for n in node_list[op.index - 1]:
                            n.activation_offload[-1] = True
                    else:
                        for n in node_list[op.index - 1]:
                            setattr(n, "activation_offload", [offload_idx, False, True])

                        offload_idx += 1

                # annotate this node
                if isinstance(fwd_list[idx + 1], ForwardCheck):
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", True)
                else:
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", [offload_idx, True, False])

                    offload_idx += 1
def solver_pofo(gm: ColoGraphModule,
                data,
                bandwidth,
                flops,
                mem_limit: int,
                mem_slots: int = 50,
                cnode: List[str] = None,
                eps: float = 0.0) -> ColoGraphModule:
    """Solver that combine offload and activation checkpoint
    Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html

    Args:
        gm (ColoGraphModule): ColoGraphModule derived from tracer
        data: input of the model
        bandwidth: offload bandwidth, unit Byte/s
        flops: FLOPS of device, unit FLOPs/s
        mem_limit (int): memory limit, unit Byte
        mem_slots (int, optional): number of memory slots. Defaults to 50.
        cnode (List[str], optional): common node for linearize. Defaults to None.
        eps (float, optional): epsilon for memory decay. Defaults to 0.0.
            Accepted for interface compatibility; it is not read by this function.

    Returns:
        ColoGraphModule: annotated graph module; the solved sequence is stored in its
        ``__sequence__`` attribute

    Raises:
        ValueError: if no sequence can be found within the memory limit
    """

    node_list = linearize(gm, cnode)
    # the parameters occupy part of the budget, only the rest is available to the solver
    mem_limit -= parameter_size(gm)

    # prepare data
    if is_compatible_with_meta():
        from colossalai.fx.profiler import MetaTensor
        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(data)
    chain: Chain = _construct_chain(node_list, data)
    chain = _normalize_flops(chain, flops)

    # currently we view loss as an op without expense
    chain.cbweight.append(0)
    chain.cweight.append(0)
    chain.fwd_mem_tmp.append(0)
    chain.bwd_mem_tmp.append(0)
    chain.fweight.append(0)
    chain.bweight.append(0)

    solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots)
    first_state = (0, 0, 0, 0, False)
    sequence = solver.pofo_rec(first_state)
    # identity check: "no solution" is signalled by the None singleton
    if sequence is None:
        raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")

    _annotate_from_pofo_sequence(sequence, node_list)
    setattr(gm, "__sequence__", sequence)
    return gm

View File

@ -1,436 +0,0 @@
import math
import sys
from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
# global variable that indicates whether the solver has failed
SOLVER_FAILED = False
# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
    """Returns the optimal table: a tuple containing:
    Opt[m][lmin][lmax] with lmin = 0...chain.length
         and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
    what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
         (False, j) if the optimal choice is a leaf checkpoint of length j
    The computation uses dynamic programming

    Args:
        chain (Chain): chain whose weights are read; memory sizes are used as table indices
        mmax: largest memory index of the table

    Returns:
        Tuple: ``(opt, what)``
    """

    # each list gets one trailing 0 so that index i + 1 is valid for the last stage
    fw = chain.fweight + [0]    ## forward time
    bw = chain.bweight    ## backward time (used for the single-stage border entries)
    cw = chain.cweight + [0]    ## size of x (and of y)
    cbw = chain.cbweight + [0]    ## size of xbar
    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
    bwd_mem_tmp = chain.bwd_mem_tmp + [0]

    # Build table
    opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

    # Initialize borders of the tables for lmax-lmin = 0
    for m in range(mmax + 1):
        for i in range(chain.length + 1):
            #lmax-lmin = 0
            limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
            if m >= limit:    ## Equation (1)
                opt[m][i][i] = fw[i] + bw[i]
            else:
                opt[m][i][i] = float("inf")

    # Compute everything
    for m in range(mmax + 1):
        # d is the span length lmax - lmin; shorter spans are filled before longer ones
        for d in range(1, chain.length + 1):
            for i in range(chain.length + 1 - d):
                # for idx in range(i+1, chain.length + 1):
                idx = i + d
                # minimal memory needed to run the forwards of this span at all
                mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
                if idx > i + 1:
                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
                if m < mmin:
                    opt[m][i][idx] = float("inf")
                else:
                    # option 1: checkpoint only the input of stage i and restart from some stage j
                    leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
                                        for j in range(i + 1, idx + 1)
                                        if m >= cw[j]]
                    if leaf_checkpoints:
                        best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                    else:
                        best_leaf = None
                    # option 2: keep xbar of stage i (chain checkpoint)
                    if m >= cbw[i + 1]:
                        chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
                    else:
                        chain_checkpoint = float("inf")
                    # ties go to the leaf checkpoint
                    if best_leaf and best_leaf[1] <= chain_checkpoint:
                        opt[m][i][idx] = best_leaf[1]
                        what[m][i][idx] = (False, best_leaf[0])
                    else:
                        opt[m][i][idx] = chain_checkpoint
                        what[m][i][idx] = (True,)
    return (opt, what)
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
    """Rebuild the operation sequence stored in ``opt_table`` for one sub-chain.

    Args:
        chain: the class describing the AC graph
        lmin: index of the first forward to execute
        lmax: upper bound index of the last forward to execute (not included)
        cmem: number of available memory slots
        opt_table: the ``(opt, what)`` pair produced by the table computation

    Returns:
        The optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]. When the table
        marks the sub-chain as infeasible, the (empty) sequence is returned and the
        module-level ``SOLVER_FAILED`` flag is set.

    Raises:
        ValueError: if ``cmem`` is not positive
    """
    global SOLVER_FAILED

    if cmem <= 0:
        raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem))

    cost_table, decision_table = opt_table
    sequence = Sequence(Function("Persistent", lmax - lmin, cmem))

    # infeasible sub-chain: announce it through the logger and raise the global failure flag
    if cost_table[cmem][lmin][lmax] == float("inf"):
        get_dist_logger().info(
            "Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
                                                                                              lmax=lmax,
                                                                                              cmem=cmem))
        SOLVER_FAILED = True
        return sequence

    # a single stage: either the loss or one forward/backward pair
    if lmin == lmax:
        if lmin == chain.length:
            sequence.insert(Loss())
        else:
            sequence.insert(ForwardEnable(lmin))
            sequence.insert(Backward(lmin))
        return sequence

    decision = decision_table[cmem][lmin][lmax]
    if decision[0]:
        # chain checkpoint: keep stage lmin with gradients and recurse on the rest
        sequence.insert(ForwardEnable(lmin))
        sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
        sequence.insert(Backward(lmin))
        return sequence

    # leaf checkpoint: run up to ``split`` without gradients, solve the tail, then redo the head
    split = decision[1]
    sequence.insert(ForwardCheck(lmin))
    for stage in range(lmin + 1, split):
        sequence.insert(ForwardNograd(stage))
    sequence.insert_sequence(_rec(chain, split, lmax, cmem - chain.cweight[split], opt_table))
    sequence.insert_sequence(_rec(chain, lmin, split - 1, cmem, opt_table))
    return sequence
def _fwd_xbar(node: List[Node]) -> int:
    """Get the forward xbar of a node

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in linearized graph

    Returns:
        int: xbar size, unit Byte
    """
    # per fx node: forward temporaries plus forward outputs, summed over the whole stage
    return sum(calculate_fwd_tmp(fx_node) + calculate_fwd_out(fx_node) for fx_node in node)
def _fwd_time(node: List[Node]) -> int:
"""Get the foward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: foward time, extimated by flops count
"""
fwd_time = 0
for n in node:
# minimum flop count is needed
fwd_time += max(n.meta['fwd_flop'], 1)
return fwd_time
def _bwd_time(node: List[Node]) -> int:
"""Get the backward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward time, extimated by flops count
"""
bwd_time = 0
for n in node:
# minimum flop count is needed
bwd_time += max(n.meta['bwd_flop'], 1)
return bwd_time
def _get_fwd_mem_tmp(node: List[Node]) -> int:
    """Get the forward temp memory of a node
    This could be done by subtracting the saved activation from all output of a node

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in linearized graph

    Returns:
        int: forward temp memory, unit Byte
    """
    # only the last fx node of the stage is inspected
    last = node[-1]
    all_output = activation_size(last.meta['fwd_out'])
    saved_output = calculate_fwd_out(last)
    return all_output - saved_output
def _get_bwd_mem_tmp(node: List[Node]) -> int:
    """Get the backward temp memory of a node

    The fx nodes of the stage are visited in reverse order while tracking which of
    them are still waiting on inputs; the result is the largest value of
    "tracked sizes + the node's own ``bwd_mem_tmp``" seen along the way.

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in linearized graph

    Returns:
        int: backward temp memory, unit Byte
    """

    def _get_deps_size():
        """Sum the sizes currently tracked in ``deps``."""
        deps_size = 0
        for k, v in deps.items():
            k: Node
            # still has outstanding inputs: its backward output is counted
            if v > 0:
                deps_size += k.meta['bwd_mem_out']
            # marked as freed: its forward temporaries and outputs are subtracted
            if v == float('-inf'):
                deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)

        return deps_size

    bwd_mem_tmp = 0
    deps = {}
    for n in reversed(node):
        deps[n] = len(n.all_input_nodes)
        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])

        # n has been processed: release one dependency of each tracked user
        for child in n.users:
            if child in deps:
                deps[child] -= 1
                if deps[child] <= 0:
                    deps[child] = float('-inf')    # free

    return bwd_mem_tmp
def _construct_chain(node_list: List[List[Node]], input) -> Chain:
    """Build the ``Chain`` description of a linearized graph.

    Args:
        node_list (List[List[Node]]): linearized graph, one list of fx nodes per stage
        input: model input; its activation size is the first entry of both size lists

    Returns:
        Chain: built from forward times, backward times, x sizes, xbar sizes and the
        forward / backward temporary memory of every stage
    """
    xbar_sizes = [activation_size(input)]
    x_sizes = [activation_size(input)]
    fwd_time, bwd_time = [], []
    tmp_fwd, tmp_bwd = [], []

    for stage in node_list:
        fwd_time.append(_fwd_time(stage))
        bwd_time.append(_bwd_time(stage))

        # the stage output is what its last fx node saves
        out_size = calculate_fwd_out(stage[-1])
        x_sizes.append(out_size)
        xbar_sizes.append(max(out_size, _fwd_xbar(stage)))

        tmp_fwd.append(_get_fwd_mem_tmp(stage))
        tmp_bwd.append(_get_bwd_mem_tmp(stage))

    # one extra backward entry with zero time; currently we view loss backward temp as zero
    bwd_time.append(0)
    tmp_bwd.append(0)

    return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
    """Write the checkpoint decisions of a rotor ``sequence`` onto the nodes of ``node_list``.

    Every node of a checkpointed stage receives an ``activation_checkpoint`` list with one
    entry per nesting level.

    Args:
        sequence (Sequence): operation sequence; it must contain a ``Loss`` op
        node_list (List[List[Node]]): linearized graph; ``node_list[k]`` holds the nodes of stage ``k``
    """
    op_list = sequence.list_operations()
    # everything before the Loss op is the forward pass, everything after it the backward pass
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    fwd_list = op_list[:op_list.index(loss_op)]
    bwd_list = op_list[op_list.index(loss_op) + 1:]
    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []

    # forward annotation
    # NOTE(review): the position in fwd_list (idx) is used as the stage index here, whereas the
    # backward loop below uses op.index -- presumably they coincide for the forward pass; confirm.
    for idx, op in enumerate(fwd_list, 0):
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(idx)

            elif isinstance(op, ForwardEnable):
                # a grad-enabled forward closes the current region
                in_ckpt = False
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                # a new checkpoint closes the current region and opens the next one
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])

                ckpt_idx += 1
                ckpt_region = [idx]

        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(idx)

    # annotate the backward if there is any nested activation checkpoint
    in_recompute = False
    for op in bwd_list:
        if in_recompute:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [op.index]

            elif isinstance(op, Backward):
                # a Backward op ends the recomputation phase
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                in_recompute = False

        else:
            if not isinstance(op, Backward):
                # any forward op inside the backward pass starts a recomputation; numbering restarts at 0
                in_recompute = True
                ckpt_idx = 0
                ckpt_region = []
                if isinstance(op, ForwardCheck):
                    ckpt_region.append(op.index)

    # postprocess, make sure every activation checkpoint label in the
    # same activation checkpoint region (level = 0) has the same length
    op_list = []
    for node in node_list:
        op_list += node
    ckpt_regions = _find_nested_ckpt_regions(op_list)
    for (start_idx, end_idx) in ckpt_regions:
        nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
        for idx in range(start_idx, end_idx + 1):
            # pad shorter labels with None
            op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
def solver_rotor(gm: ColoGraphModule,
                 data,
                 mem_limit: int,
                 mem_slots: int = 500,
                 cnode: List[str] = None,
                 eps: float = 0.0,
                 force_python: bool = False) -> ColoGraphModule:
    """solver that automatically find activation checkpoint in rotor's manner

    Args:
        gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
        data (torch.Tensor): input data.
        mem_limit (int): memory budget in Byte.
        mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
        cnode (List[str], optional): common node list for linearize. Defaults to None.
        eps (float): epsilon for memory decay. Defaults to 0.0
        force_python (bool): force to use python version of dynamic programs

    Returns:
        ColoGraphModule: annotated ColoGraphModule with ``__sequence__`` attribute (``None`` when
        the solver failed) and ``__opttable__`` attribute

    Raises:
        RuntimeError: if some node of ``gm`` carries no meta information
    """
    global SOLVER_FAILED

    # try to import C version solver if force_python is not set
    logger = get_dist_logger()
    if not force_python:
        try:
            from .dynamic_programs_C_version import persistent_compute_table
            CVERSION = True

        # build module if module not found
        except ModuleNotFoundError:
            import os
            import subprocess
            logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
            this_dir = os.path.dirname(os.path.abspath(__file__))
            # subprocess.run() drains both pipes while it waits; Popen(...).wait() with PIPEs
            # can block forever once the build output fills an OS pipe buffer.
            result = subprocess.run(
                [
                    f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
                    f"--build-lib={this_dir}"
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            if result.returncode == 0:
                logger.info("dynamic_programs_C_version has been built!", ranks=[0])
                from .dynamic_programs_C_version import persistent_compute_table
                CVERSION = True
            else:
                logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
                CVERSION = False
    else:
        CVERSION = False

    # check if metainfoprop is done
    if any(len(node.meta) == 0 for node in gm.graph.nodes):
        raise RuntimeError(
            "Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")

    # linearize the graph
    node_list = linearize(gm, cnode)

    # construct chain
    mem_unit = mem_limit * (1.0 - eps) // mem_slots
    chain: Chain = _construct_chain(node_list, data)
    chain._discretize(mem_unit)

    # use C version if possible
    if CVERSION and not force_python:
        logger.info("Using C version rotor solver!", ranks=[0])
        opt_table = persistent_compute_table(chain, mem_slots)
    else:
        opt_table = _compute_table(chain, mem_slots)
        logger.info("Using python version rotor solver!", ranks=[0])

    # The failure flag is module-global: clear it so that a failure of an earlier call
    # does not suppress the annotation of this one.
    SOLVER_FAILED = False

    # found sequence
    sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)

    if SOLVER_FAILED:
        # if solver failed, we don't need to annotate the graph
        setattr(gm, "__sequence__", None)
    else:
        _annotate_from_sequence(sequence, node_list)
        # set __sequence__ attribute to GraphModule
        setattr(gm, "__sequence__", sequence)

    # set __opttable__ attribute to GraphModule
    setattr(gm, "__opttable__", opt_table[0])
    gm.recompile()
    return gm

View File

@ -1,516 +0,0 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
/*
 * Copy a Python sequence of integers into a newly allocated C array.
 *
 * The returned buffer holds len + 1 entries: the converted items followed by
 * one extra trailing 0. The caller owns the buffer and must free() it.
 *
 * Returns NULL on failure. A Python exception is set for every failure except
 * a NULL `pylist`, where the caller's failed lookup is expected to have set
 * one already.
 */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!pylist) return NULL;
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence of integers");
    return NULL;
  }
  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception already set by PySequence_Size */
  long* result = (long*)calloc((size_t)len + 1, sizeof(long));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    result[i] = PyLong_AsLong(item);
    Py_DECREF(item);
    /* -1 is a legal value, so the error indicator must be consulted too. */
    if (result[i] == -1 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
  }
  result[len] = 0;
  return result;
}
/*
 * Copy a Python sequence of numbers into a newly allocated C array of double.
 *
 * The returned buffer holds len + 1 entries: the converted items followed by
 * one extra trailing 0. The caller owns the buffer and must free() it.
 *
 * Returns NULL on failure. A Python exception is set for every failure except
 * a NULL `pylist`, where the caller's failed lookup is expected to have set
 * one already.
 */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!pylist) return NULL;
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence of numbers");
    return NULL;
  }
  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception already set by PySequence_Size */
  double* result = (double*)calloc((size_t)len + 1, sizeof(double));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    result[i] = PyFloat_AsDouble(item);
    Py_DECREF(item);
    /* -1.0 is a legal value, so the error indicator must be consulted too. */
    if (result[i] == -1.0 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
  }
  result[len] = 0;
  return result;
}
/*
 * Read attribute `attributeName` from `container` and convert it to a C array
 * of long via PySequenceToLongArray.
 *
 * Returns a buffer the caller must free(), or NULL on failure. When the
 * attribute lookup itself fails, the exception it raised is left in place.
 */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A failed lookup yields NULL, which must not reach Py_DECREF. */
  if (!sequence) return NULL;
  long* result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * Read attribute `attributeName` from `container` and convert it to a C array
 * of double via PySequenceToDoubleArray.
 *
 * Returns a buffer the caller must free(), or NULL on failure. When the
 * attribute lookup itself fails, the exception it raised is left in place.
 */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A failed lookup yields NULL, which must not reach Py_DECREF. */
  if (!sequence) return NULL;
  double* result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * persistent_compute_table(chain, mmax) -> (opt, what)
 *
 * Reads the attributes fweight, bweight, cweight, cbweight, fwd_mem_tmp,
 * bwd_mem_tmp and length from `chain` and fills two tables indexed by
 * (m, i, l) with 0 <= m <= mmax and 0 <= i <= l <= length.
 *
 * Returns a 2-tuple of lists of length mmax + 1:
 *   opt[m][i]  is a dict {l: float cost}, INFINITY when the cell is infeasible;
 *   what[m][i] is a dict {l: (True,)} when the stored choice is negative,
 *              otherwise {l: (False, j)} with the stored index j.
 *
 * The loops index the weight arrays up to position length + 1; this relies on
 * the extra trailing 0 appended by the array helpers.
 * NOTE(review): assumes the chain lists have the lengths enforced by
 * Chain.check_lengths; shorter lists would be read out of bounds.
 *
 * On failure returns NULL with a Python exception set; every temporary buffer
 * is released on all paths.
 */
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
  PyObject* chain_param;
  int mmax;
  double* fw = NULL;
  double* bw = NULL;
  long* cw = NULL;
  long* cbw = NULL;
  long* fwd_tmp = NULL;
  long* bwd_tmp = NULL;
  double* opt = NULL;
  long* what = NULL;
  PyObject* chain_length_param = NULL;
  PyObject* res_opt = NULL;
  PyObject* res_what = NULL;
  PyObject* result = NULL;
  long chain_length = 0;
  size_t table_size = 0;

  if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
  if (mmax < 0) {
    PyErr_SetString(PyExc_ValueError, "mmax must be non-negative");
    return NULL;
  }

  fw = getDoubleArray(chain_param, "fweight");
  if (!fw) goto done;
  bw = getDoubleArray(chain_param, "bweight");
  if (!bw) goto done;
  cw = getLongArray(chain_param, "cweight");
  if (!cw) goto done;
  cbw = getLongArray(chain_param, "cbweight");
  if (!cbw) goto done;
  /* Each buffer is checked against itself (the original re-tested cbw). */
  fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
  if (!fwd_tmp) goto done;
  bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
  if (!bwd_tmp) goto done;

  chain_length_param = PyObject_GetAttrString(chain_param, "length");
  if (!chain_length_param) goto done;
  chain_length = PyLong_AsLong(chain_length_param);
  Py_DECREF(chain_length_param);
  if (chain_length < 0) {
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_ValueError, "chain length must be non-negative");
    goto done;
  }

  // TODO: Can be optimized by only allocating memory for l >= i
  // TODO: float / int instead of double / long ?
  table_size = (size_t)(mmax + 1) * (size_t)(chain_length + 1) *
               (size_t)(chain_length + 1);
#define OPT(m, i, l)                                  \
  opt[(m) * (chain_length + 1) * (chain_length + 1) + \
      (i) * (chain_length + 1) + (l)]
  opt = (double*)calloc(table_size, sizeof(double));
#define WHAT(m, i, l)                                  \
  what[(m) * (chain_length + 1) * (chain_length + 1) + \
       (i) * (chain_length + 1) + (l)]
  what = (long*)calloc(table_size, sizeof(long));
  if (!opt || !what) {
    PyErr_NoMemory();
    goto done;
  }

  /* Diagonal cells (i, i): cost of one forward + backward when it fits in m. */
  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i <= chain_length; ++i)
      // TODO: Can be optimized to remove the IF by reordering loops
      if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
          (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
        OPT(m, i, i) = fw[i] + bw[i];
      else
        OPT(m, i, i) = INFINITY;

  /* Off-diagonal cells, by increasing distance d = idx - i. */
  for (long m = 0; m <= mmax; ++m)
    for (long d = 1; d <= chain_length; ++d) {
      for (long i = 0; i <= chain_length - d; ++i) {
        long idx = i + d;
        long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
          }
          mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
        }
        if ((m >= mmin)) {
          long bestLeaf = -1;
          double sumFw = 0;
          double bestLeafCost = INFINITY;
          /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
          /// i+1
          for (long j = i + 1; j <= idx; ++j) {
            sumFw += fw[j - 1];
            if (m >= cw[j]) {
              double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
              if (cost < bestLeafCost) {
                bestLeafCost = cost;
                bestLeaf = j;
              }
            }
          }
          double chainCost = INFINITY;
          if (m >= cbw[i + 1])
            chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
          if (bestLeafCost <= chainCost) {
            OPT(m, i, idx) = bestLeafCost;
            WHAT(m, i, idx) = bestLeaf;
          } else {
            OPT(m, i, idx) = chainCost;
            WHAT(m, i, idx) = -1;
          }
        } else
          OPT(m, i, idx) = INFINITY;
      }
    }

  // Convert the result into Python world
  res_opt = PyList_New(mmax + 1);
  res_what = PyList_New(mmax + 1);
  if (!res_opt || !res_what) goto done;
  for (long m = 0; m <= mmax; ++m) {
    /* Children are stored into their parent right away, so releasing the two
       top-level lists on failure releases everything built so far. */
    PyObject* res_opt_m = PyList_New(chain_length + 1);
    if (!res_opt_m) goto done;
    PyList_SET_ITEM(res_opt, m, res_opt_m);
    PyObject* res_what_m = PyList_New(chain_length + 1);
    if (!res_what_m) goto done;
    PyList_SET_ITEM(res_what, m, res_what_m);
    for (long i = 0; i <= chain_length; ++i) {
      PyObject* res_opt_m_i = PyDict_New();
      if (!res_opt_m_i) goto done;
      PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
      PyObject* res_what_m_i = PyDict_New();
      if (!res_what_m_i) goto done;
      PyList_SET_ITEM(res_what_m, i, res_what_m_i);
      for (long l = i; l <= chain_length; ++l) {
        long what_m_i_l = WHAT(m, i, l);
        PyObject* res_l = PyLong_FromLong(l);
        PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
        PyObject* res_what_m_i_l =
            (what_m_i_l < 0) ? Py_BuildValue("(O)", Py_True)
                             : Py_BuildValue("(Ol)", Py_False, what_m_i_l);
        int failed =
            !res_l || !res_opt_m_i_l || !res_what_m_i_l ||
            PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l) < 0 ||
            PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l) < 0;
        Py_XDECREF(res_opt_m_i_l);
        Py_XDECREF(res_what_m_i_l);
        Py_XDECREF(res_l);
        if (failed) goto done;
      }
    }
  }
  result = PyTuple_Pack(2, res_opt, res_what);

done:
  free(fw);
  free(bw);
  free(cw);
  free(cbw);
  free(fwd_tmp);
  free(bwd_tmp);
  free(opt);
  free(what);
  Py_XDECREF(res_opt);
  Py_XDECREF(res_what);
  /* Never hand NULL back to the interpreter without an exception set. */
  if (!result && !PyErr_Occurred())
    PyErr_SetString(PyExc_RuntimeError,
                    "persistent_compute_table: could not read the chain");
  return result;
}
// Flat offset into the tables used by floating_compute_table.
// Arguments as passed by the OPT/WHAT macros:
// long i = L - s, j = t - s, k = l - t
// `m` selects a slab of `m_factor` cells; (i, j, k) locate the cell inside it.
// Declared `static inline`: a plain C99 `inline` definition provides no
// external definition, so any call the compiler does not inline would fail
// to link.
static inline long floating_index_in_array(long m_factor, long m, long i,
                                           long j, long k) {
  return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
         (j * (j - 1)) / 2 + k;
}
// Triple of indices stored per table cell by floating_compute_table for the
// best "leaf" choice it found. That function writes sp = -1 when no leaf was
// chosen and tests sp >= 0 when converting the table to Python objects.
typedef struct {
long sp;
long r;
long tp;
} index_t;
// floating_compute_table(chain, mmax) -> (opt, what)
//
// Fills two tables indexed by (m, s, t, l) with 0 <= m <= mmax and
// 0 <= s <= t <= l <= chain.length, then returns a 2-tuple of lists of
// length mmax + 1. Each list element is a dict keyed by the tuple (s, t, l):
//   opt  values are floats (INFINITY when no option was feasible);
//   what values are (True,) when no leaf was recorded, otherwise
//        (False, (sp, r, tp)).
//
// NOTE(review): the attribute names read here ("fweigth", "bweigth",
// "cweigth", "cbweigth", "fwd_tmp", "bwd_tmp") are spelled differently from
// the ones persistent_compute_table reads ("fweight", ..., "fwd_mem_tmp").
// Confirm which chain object this function is meant to receive.
// NOTE(review): buffers obtained earlier are not freed on the early
// `return NULL` paths below.
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
if (!fwd_tmp) return NULL;
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
if (!bwd_tmp) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// Number of table cells reserved per value of m.
const long m_factor =
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
// Defined for 0 <= s <= t <= l <= chain_length, for all m
#undef OPT
#define OPT(m, s, t, l) \
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
#undef WHAT
#define WHAT(m, s, t, l) \
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
// partialSumsFW[i] = fw[0] + ... + fw[i - 1] (exclusive prefix sums).
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
partialSumsFW[i] = total;
total += fw[i];
}
partialSumsFW[chain_length] = total;
// Base cells (i, i, i).
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
OPT(m, i, i, i) = fw[i] + bw[i];
else
OPT(m, i, i, i) = INFINITY;
}
// Remaining cells, by increasing distance d between s and l.
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chain_length; ++d) { // d = l - s
for (long s = 0; s <= chain_length - d; ++s) {
long l = s + d;
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
long memNullSecond = 0;
for (long j = s + 1; j < l; ++j) {
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
if (val > memNullSecond) memNullSecond = val;
}
for (long t = s; t <= l; ++t) {
// The "chain" option is only considered when s == t.
double chainCost = INFINITY;
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
}
// The "leaf" option: search over all (r, tp, sp) triples.
double bestLeafCost = INFINITY;
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
for (long r = s; r <= t; ++r)
if (cw[s] <= cw[r])
for (long tp = t + 1; tp <= l; ++tp)
for (long sp = r + 1; sp <= tp; ++sp) {
long mp = m - cw[r] + cw[s];
assert(mp >= 0);
if (mp >= cw[sp]) {
double value = partialSumsFW[sp] - partialSumsFW[s] +
OPT(mp - cw[sp], sp, tp, l) +
OPT(mp, r, t, tp - 1);
if (value < bestLeafCost) {
bestLeafCost = value;
bestLeaf.sp = sp;
bestLeaf.r = r;
bestLeaf.tp = tp;
}
}
}
}
// Ties between the two options go to the leaf.
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
OPT(m, s, t, l) = bestLeafCost;
WHAT(m, s, t, l).sp = bestLeaf.sp;
WHAT(m, s, t, l).r = bestLeaf.r;
WHAT(m, s, t, l).tp = bestLeaf.tp;
} else {
OPT(m, s, t, l) = chainCost;
WHAT(m, s, t, l).sp = -1;
}
}
}
}
// NOTE(review): partialSumsFW is not passed to free() anywhere in this function.
free(fw);
free(bw);
free(cw);
free(cbw);
free(fwd_tmp);
free(bwd_tmp);
PyObject* res_opt = PyList_New(mmax + 1);
PyObject* res_what = PyList_New(mmax + 1);
// Convert the result into Python world
// One shared (True,) tuple is reused for every cell with no recorded leaf.
PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyDict_New();
PyList_SET_ITEM(res_opt, m, res_opt_m);
PyObject* res_what_m = PyDict_New();
PyList_SET_ITEM(res_what, m, res_what_m);
for (long s = 0; s <= chain_length; ++s)
for (long t = s; t <= chain_length; ++t)
for (long l = t; l <= chain_length; ++l) {
PyObject* key = Py_BuildValue("(lll)", s, t, l);
PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
PyDict_SetItem(res_opt_m, key, value_opt);
PyObject* value_what = true_tuple;
index_t* idx_what = &WHAT(m, s, t, l);
if (idx_what->sp >= 0)
value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
idx_what->r, idx_what->tp);
PyDict_SetItem(res_what_m, key, value_what);
// Only drop the reference when a fresh tuple was built for this cell.
if (value_what != true_tuple) Py_DECREF(value_what);
Py_DECREF(key);
Py_DECREF(value_opt);
}
}
Py_DECREF(true_tuple);
free(opt);
free(what);
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
Py_DECREF(res_opt);
Py_DECREF(res_what);
return result;
}
// griewank_heterogeneous_compute_table(chain, mmax) -> opt
//
// Fills a cost table indexed by (m, i, l) with 0 <= m <= mmax and
// 0 <= i <= l <= chain.length, and returns it as a list of length mmax + 1.
// opt[m][i] is a dict whose keys are the offsets l - i (not l itself) and
// whose values are float costs, INFINITY when the cell is infeasible.
//
// NOTE(review): the attribute names read here ("fweigth", "bweigth",
// "cweigth", "cbweigth") are spelled differently from the ones
// persistent_compute_table reads ("fweight", ...). Confirm which chain object
// this function is meant to receive.
// NOTE(review): buffers obtained earlier are not freed on the early
// `return NULL` paths below.
static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#undef OPT
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double* opt = (double*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
// Compute partial sums
// sumfw[i]    = fw[0] + ... + fw[i]        (inclusive prefix sums)
// sumbw[i]    = bw[0] + ... + bw[i]
// sumsumfw[i] = sumfw[0] + ... + sumfw[i]
double* sumfw = (double*)calloc(chain_length, sizeof(double));
double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
total += fw[i];
sumfw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length + 1; ++i) {
total += bw[i];
sumbw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length; ++i) {
total += sumfw[i];
sumsumfw[i] = total;
}
// Base cells: (i, i) and, when it exists, (i, i + 1).
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
OPT(m, i, i) = bw[i];
else
OPT(m, i, i) = INFINITY;
if (i < chain_length) {
long maxC = fmaxl(cw[i], cw[i + 1]);
long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
else
OPT(m, i, i + 1) = INFINITY;
}
}
// Cells with l >= i + 2; maxCW_il and maxCostFWD are running maxima that
// are extended as l grows.
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i + 2 <= chain_length; ++i) {
long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
for (long l = i + 2; l <= chain_length; ++l) {
maxCW_il = fmax(maxCW_il, cw[l + 1]);
maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
long mmin = fmaxl(mminCst, maxCostFWD);
if ((m >= mmin)) {
// Cost of the option that uses no intermediate split point.
double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
noCheckpointCost +=
sumsumfw[l - 1] -
(i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
// Best cost over all split points j in (i, l).
double valueCost = INFINITY;
if (m >= cw[i]) {
double sumFwds = 0;
for (long j = i + 1; j < l; ++j) {
sumFwds += fw[j - 1];
valueCost = fmin(
valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
}
}
OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
} else
OPT(m, i, l) = INFINITY;
}
}
free(sumfw);
free(sumbw);
free(sumsumfw);
free(fw);
free(bw);
free(cw);
free(cbw);
PyObject* res_opt = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_opt, m, res_opt_m);
for (long i = 0; i <= chain_length; ++i) {
PyObject* res_opt_m_i = PyDict_New();
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
for (long l = i; l <= chain_length; ++l) {
// The dict key is the offset l - i, not l.
PyObject* res_l = PyLong_FromLong(l - i);
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
Py_DECREF(res_opt_m_i_l);
Py_DECREF(res_l);
}
}
}
free(opt);
return res_opt;
}
static PyMethodDef dynamic_programs_methods[] = {
{"persistent_compute_table", persistent_compute_table, METH_VARARGS,
"Compute the optimal table with the persistent algorithm."},
{"floating_compute_table", floating_compute_table, METH_VARARGS,
"Compute the optimal table with the floating algorithm."},
{"griewank_heterogeneous_compute_table",
griewank_heterogeneous_compute_table, METH_VARARGS,
"Compute the optimal table for the Griewank Heterogeneous Model."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
static struct PyModuleDef dynamic_programs_module = {
PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
dynamic_programs_methods};
PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
return PyModule_Create(&dynamic_programs_module);
}

View File

@ -1,94 +0,0 @@
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace
# Common nodes are nodes that can be seen as attributes and remain
# unchanged throughout the whole model. They are used several times by
# different blocks of the model, which makes it hard for us to linearize the
# graph when we encounter them. We let users annotate some of the
# inputs as common nodes, such as the attention mask, and the following are some
# of the ops that can be seen as common nodes. With our common node propagation,
# we can find some of the "real" common nodes (e.g. the real attention mask
# used in BERT and GPT). The rule is simple: a node whose parents are all common
# nodes, or whose op belongs to the following operations, is viewed as a
# newly born common node.
# List of target names that can be seen as common nodes
COPS = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in COPS
else:
return target.__name__ in COPS
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
    """Linearize the graph into a list of regions.

    Nodes are visited in graph order and accumulated into a region; a region is
    closed as soon as no tracked dependency is outstanding and the current node
    has no inplace user.

    Args:
        gm (GraphModule): GraphModule derived by tracing
        cnode (List[str], optional): common node list, should be a subset of the
            inputs. Defaults to None. The list passed in is left untouched.

    Returns:
        List[List[Node]]: List of lists; each inner list of Node represents
            the actual 'node' in a linearized manner.

    Raises:
        ValueError: if a name in ``cnode`` does not exist in the graph.
        AssertionError: if a name in ``cnode`` is not a placeholder.

    Remarks:
        We merge the inplace ops into the previous node.
    """

    def _is_sink() -> bool:
        """Check if we can free all dependencies.

        Reads ``deps`` and the node ``n`` currently visited by the loop below.

        Returns:
            bool
        """
        return not sum(deps.values()) and not any(map(is_inplace, n.users))

    # make sure that every item in cnode is valid
    if cnode:
        for name in cnode:
            try:
                assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
                    f"common node {name} is not an input of the model"
            except StopIteration:
                raise ValueError(f"common node name {name} not in graph")
        # Propagation below appends derived (non-placeholder) names. Work on a
        # private copy so the caller's list stays valid for later calls.
        cnode = list(cnode)
    else:
        cnode = []

    deps = {}
    linearized_nodes = []
    region = []

    for n in gm.graph.nodes:
        if n.op != "placeholder" and n.op != "output":
            # this node consumes one pending use of each tracked parent
            for n_par in n._input_nodes:
                if n_par.op != "placeholder" and n_par.name not in cnode:
                    deps[n_par] -= 1
            region.append(n)

            # if the node could free all dependencies in graph
            # we could begin a new node
            if _is_sink():
                linearized_nodes.append(region)
                region = []

            # propagate common node attr if possible
            if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode
                                          ]) or _is_cop(n.target):
                cnode.append(n.name)
            else:
                # users that are the graph output do not count as dependencies
                deps[n] = len([user for user in n.users if user.op != "output"])
    return linearized_nodes

View File

@ -1,270 +0,0 @@
import math
def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]
class Chain:
    """Per-stage weights of a linearized chain.

    ``length`` is ``len(fw)``. ``fw`` and ``ftmp`` are expected to hold
    ``length`` entries; ``bw``, ``cw``, ``cbw`` and ``btmp`` hold ``length + 1``.
    """

    def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
        """Store the six lists; raise AttributeError if ``check`` is set and their lengths disagree."""
        self.fweight, self.bweight = fw, bw
        self.cweight, self.cbweight = cw, cbw
        self.fwd_mem_tmp, self.bwd_mem_tmp = ftmp, btmp
        self.length = len(fw)
        if check and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        """Return True when every stored list has the size implied by ``self.length``."""
        n = self.length
        expected_sizes = (
            (self.fweight, n),
            (self.bweight, n + 1),
            (self.cweight, n + 1),
            (self.fwd_mem_tmp, n),
            (self.bwd_mem_tmp, n + 1),
            (self.cbweight, n + 1),
        )
        return all(len(values) == size for values, size in expected_sizes)

    def __repr__(self):
        """Render one tuple per stage, plus a final tuple with None for the forward-only fields."""
        last = self.length
        rows = [(self.fweight[k], self.bweight[k], self.cweight[k], self.cbweight[k], self.fwd_mem_tmp[k],
                 self.bwd_mem_tmp[k]) for k in range(last)]
        rows.append((None, self.bweight[last], self.cweight[last], self.cbweight[last], None, self.bwd_mem_tmp[last]))
        return repr(rows)

    def _discretize(self, mem_unit):
        """Replace the four memory lists by their values in units of ``mem_unit``, rounded up."""
        for attr in ("cweight", "cbweight", "fwd_mem_tmp", "bwd_mem_tmp"):
            # resolves to the module-level helper, not to this method
            setattr(self, attr, _discretize(mem_unit, getattr(self, attr)))
class Operation:
    """Base class of the entries collected by ``Sequence.list_operations``.

    Subclasses store their position in ``self.index``.
    """

    def shift(self, value):
        """Add ``value`` to ``self.index``, element-wise when the index is a tuple."""
        if type(self.index) is not tuple:
            self.index += value
            return
        self.index = tuple(x + value for x in self.index)
class Offload(Operation):
    """Operation shown as ``Off_<index>``, or ``OffwBar_<index>`` when ``has_bar`` is truthy."""

    def __init__(self, index, has_bar=False) -> None:
        super().__init__()
        self.index = index
        self.has_bar = has_bar
        self.name = "OffwBar" if has_bar else "Off"

    def __repr__(self):
        return "{}_{}".format(self.name, self.index)
class Prefetch(Operation):
    """Operation shown as ``Pre_<index>``, or ``PrewBar_<index>`` when ``has_bar`` is truthy."""

    def __init__(self, index, has_bar=False) -> None:
        super().__init__()
        self.index = index
        self.has_bar = has_bar
        self.name = "PrewBar" if has_bar else "Pre"

    def __repr__(self):
        return "{}_{}".format(self.name, self.index)
class Forward(Operation):
    """Forward operation at position ``index``, shown as ``F_<index>``."""

    def __init__(self, index):
        self.index = index
        self.name = "F"

    def __repr__(self):
        return f"{self.name}_{self.index}"

    def cost(self, chain: Chain):
        """Return ``chain.fweight[index]``, or 1 when no chain is given."""
        if chain is None:
            return 1
        return chain.fweight[self.index]
class ForwardEnable(Forward):
    """``Forward`` variant displayed with the ``Fe`` prefix."""

    def __init__(self, index):
        super().__init__(index)
        # replace the "F" label assigned by Forward.__init__
        self.name = "Fe"
class ForwardNograd(Forward):
    """``Forward`` variant displayed with the ``Fn`` prefix."""

    def __init__(self, index):
        super().__init__(index)
        # replace the "F" label assigned by Forward.__init__
        self.name = "Fn"
class ForwardCheck(Forward):
    """``Forward`` variant displayed with the ``CF`` prefix."""

    def __init__(self, index):
        super().__init__(index)
        # replace the "F" label assigned by Forward.__init__
        self.name = "CF"
class Forwards(Operation):
    """Run of forward operations over the inclusive range ``start..end``."""

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        first, last = self.index[0], self.index[1]
        return f"F_{first}->{last}"

    def cost(self, chain: Chain):
        """Sum ``chain.fweight`` over the range, or count its positions when no chain is given."""
        if chain is None:
            return (self.index[1] - self.index[0] + 1)
        return sum(chain.fweight[self.index[0]:self.index[1] + 1])
def isForward(op):
    """Return True only when ``op`` is exactly a ``Forward`` or a ``Forwards`` (subclasses excluded)."""
    kind = type(op)
    return kind is Forward or kind is Forwards
class Backward(Operation):
    """Backward operation at position ``index``, shown as ``B_<index>``."""

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return f"B_{self.index}"

    def cost(self, chain: Chain):
        """Return ``chain.bweight[index]``, or 1 when no chain is given."""
        if chain is None:
            return 1
        return chain.bweight[self.index]
class Loss(Operation):
    """Marker operation shown as ``L``; it has no index and costs nothing."""

    def __init__(self):
        # intentionally stores nothing
        pass

    def __repr__(self):
        return "L"

    def cost(self, chain):
        """Always 0, whatever ``chain`` is."""
        return 0
class MemoryAccess(Operation):
    """Zero-cost operation at position ``index``.

    ``__repr__`` reads ``self.name``, which this class does not set itself.
    """

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return f"{self.name}_{self.index}"

    def cost(self, chain: Chain):
        """Always 0, whatever ``chain`` is."""
        return 0
class WriteMemory(MemoryAccess):
    """``MemoryAccess`` displayed with the ``WM`` prefix."""

    def __init__(self, index):
        super().__init__(index)
        # label used by MemoryAccess.__repr__
        self.name = "WM"
class ReadMemory(MemoryAccess):
    """``MemoryAccess`` displayed with the ``RM`` prefix."""

    def __init__(self, index):
        super().__init__(index)
        # label used by MemoryAccess.__repr__
        self.name = "RM"
class DiscardMemory(MemoryAccess):
    """``MemoryAccess`` displayed with the ``DM`` prefix."""

    def __init__(self, index):
        super().__init__(index)
        # label used by MemoryAccess.__repr__
        self.name = "DM"
class Function:
    """Name plus positional arguments, rendered as ``name(arg1,arg2,...)``."""

    def __init__(self, name, *args):
        self.name = name
        self.args = args
        # argument text is built once, at construction time
        self.str_args = ','.join(map(str, args))

    def __repr__(self):
        return f"{self.name}({self.str_args})"
class Sequence:
    """Ordered container of ``Operation`` objects and nested ``Sequence`` objects."""

    def __init__(self, function):
        self.sequence = []    # list of Operation and Sequence
        self.function = function    # description of the function (name and parameters)

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        """Return the operations in order, with nested sequences flattened."""
        flat = []
        for item in self.sequence:
            if isinstance(item, Operation):
                flat.append(item)
                continue
            assert isinstance(item, Sequence)
            flat.extend(item.list_operations())
        return flat

    def insert(self, operation):
        """Append ``operation`` at the end."""
        self.sequence.append(operation)

    def remove(self, operation_index):
        """Delete the entry at position ``operation_index``."""
        del self.sequence[operation_index]

    def insert_sequence(self, sequence):
        """Append a nested ``sequence`` at the end."""
        self.sequence.append(sequence)

    def shift(self, value):
        """Shift every entry by ``value`` and return ``self``."""
        for entry in self.sequence:
            entry.shift(value)
        return self

    def remove_useless_write(self):
        """Drop the first entry when it is a ``WriteMemory``; return ``self``."""
        if self.sequence and isinstance(self.sequence[0], WriteMemory):
            self.remove(0)
        return self

    def get_makespan(self, chain):
        """Return the summed ``cost(chain)`` of all flattened operations."""
        total = 0
        for op in self.list_operations():
            total += op.cost(chain)
        return total

    def without_suffix(self):
        """Strip the trailing run of ``ForwardEnable`` ops that precedes the first ``Loss``.

        Returns ``(self, None)`` when the op right before the first ``Loss`` is not
        a ``ForwardEnable``. Otherwise returns ``(stripped, start)``: ``stripped``
        is a new ``Sequence`` without that run and without the ``Backward`` ops
        that follow the ``Loss`` for it, and ``start`` is the index of the first
        op of the run.
        """
        ops = self.list_operations()
        loss_pos = [pos for pos, op in enumerate(ops) if type(op) is Loss][0]
        # position of the last op before the Loss that is not a ForwardEnable
        last_idx = max((pos for pos, op in enumerate(ops[:loss_pos]) if type(op) is not ForwardEnable), default=-1)
        if last_idx == loss_pos - 1:
            return (self, None)

        # Assumes the first phase ends with the forward of the last stage and
        # that the second phase starts with the matching backwards.
        chain_length = ops[loss_pos - 1].index
        first_enabled = ops[last_idx + 1].index
        stripped = Sequence(Function("Strip", self.function.name, *self.function.args, first_enabled))
        for op in ops[:last_idx + 1]:
            stripped.insert(op)
        stripped.insert(Loss())

        # the ops right after the Loss must be the backwards of the removed stages
        for offset, expected in enumerate(range(chain_length, first_enabled - 1, -1)):
            backward = ops[loss_pos + 1 + offset]
            assert type(backward) is Backward
            assert backward.index == expected

        for pos in range(loss_pos + 2 + chain_length - first_enabled, len(ops)):
            stripped.insert(ops[pos])
        return (stripped, first_enabled)