mirror of https://github.com/hpcaitech/ColossalAI
parent
55dcd3051a
commit
b42d3d28ed
|
@ -1,4 +0,0 @@
|
|||
from .ckpt_solver_chen import chen_greedy
|
||||
from .linearize import linearize
|
||||
from .ckpt_solver_rotor import solver_rotor
|
||||
from .ckpt_solver_pofo import solver_pofo
|
|
@ -1,15 +0,0 @@
|
|||
from setuptools import setup, Extension
|
||||
import os
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_modules = [Extension(
|
||||
'dynamic_programs_C_version',
|
||||
sources=[os.path.join(this_dir, 'dynamic_programs.c')],
|
||||
)]
|
||||
|
||||
setup(
|
||||
name='rotor c extension',
|
||||
version='0.1',
|
||||
description='rotor c extension for faster dp computing',
|
||||
ext_modules=ext_modules,
|
||||
)
|
|
@ -1,98 +0,0 @@
|
|||
import math
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
||||
|
||||
__all__ = ['chen_greedy']
|
||||
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
|
||||
|
||||
|
||||
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
|
||||
"""
|
||||
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
|
||||
"""
|
||||
|
||||
def is_sink():
|
||||
"""
|
||||
If we can free all memories when executing a certain node, it is a sink.
|
||||
"""
|
||||
return not sum((v for k, v in deps.items()))
|
||||
|
||||
deps = {}
|
||||
ckpt_nodes = []
|
||||
for n in gm.graph.nodes:
|
||||
for n_par in n._input_nodes:
|
||||
deps[n_par] -= 1 # free memory and dependencies
|
||||
|
||||
# We can only put act_ckpt on these nodes
|
||||
if n.op in CKPT_OP and is_sink():
|
||||
ckpt_nodes.append(n)
|
||||
deps[n] = len(n.users) # add dependencies for future executions
|
||||
return ckpt_nodes
|
||||
|
||||
|
||||
def chen_greedy(gm: GraphModule) -> GraphModule:
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
|
||||
|
||||
Usage:
|
||||
model = resnet18()
|
||||
input_sample = torch.rand(4, 3, 224, 224)
|
||||
gm = symbolic_trace(model)
|
||||
MetaInfoProp(gm).run(input_sample)
|
||||
gm = chen_greedy(gm)
|
||||
|
||||
Args:
|
||||
gm (GraphModule): The module to add checkpoints
|
||||
"""
|
||||
|
||||
def grid_search(num_grids: int = 6) -> Set:
|
||||
"""
|
||||
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
|
||||
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
|
||||
"""
|
||||
_, b_approx = run_chen_greedy(0)
|
||||
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
||||
b_opt = math.inf
|
||||
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
|
||||
ckpt_intv, b_approx = run_chen_greedy(b)
|
||||
if b_approx < b_opt:
|
||||
b_opt = b_approx
|
||||
ckpt_opt = ckpt_intv
|
||||
return ckpt_opt
|
||||
|
||||
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
"""
|
||||
ckpt_nodes = _all_potential_ckpt_nodes(gm)
|
||||
ckpt_intv = []
|
||||
temp = 0
|
||||
x = 0
|
||||
y = 0
|
||||
prev_idx = 2
|
||||
for (idx, n) in enumerate(gm.graph.nodes):
|
||||
n: Node
|
||||
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
|
||||
y = max(y, temp)
|
||||
if temp > b and n in ckpt_nodes:
|
||||
x += calculate_fwd_in(n)
|
||||
temp = 0
|
||||
ckpt_intv.append((prev_idx, idx + 1))
|
||||
prev_idx = idx + 1
|
||||
return ckpt_intv, math.floor(math.sqrt(x * y))
|
||||
|
||||
gm.graph.lint() # make sure nodes are in topological order
|
||||
ckpt = grid_search(num_grids=6)
|
||||
node_list = list(gm.graph.nodes)
|
||||
for i, seg in enumerate(ckpt):
|
||||
for idx in range(*seg):
|
||||
n = node_list[idx]
|
||||
if n.op in CKPT_OP:
|
||||
setattr(n, 'activation_checkpoint', i)
|
||||
gm.recompile()
|
||||
return gm
|
|
@ -1,537 +0,0 @@
|
|||
import copy
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from colossalai.fx import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import \
|
||||
_find_nested_ckpt_regions
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from .linearize import linearize
|
||||
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
|
||||
Sequence)
|
||||
|
||||
INF = float("inf")
|
||||
|
||||
|
||||
def _normalize_flops(chain: Chain, flops) -> Chain:
|
||||
"""
|
||||
Normalize flops
|
||||
"""
|
||||
for i in range(chain.length):
|
||||
chain.fweight[i] /= flops
|
||||
chain.bweight[i] /= flops
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
class PofoTable:
|
||||
"""PofoTable
|
||||
The PofoTable contains the necessary components to store intermediate results
|
||||
of dynamic programming and the operations alone the way.
|
||||
"""
|
||||
|
||||
def __init__(self, chain_length: int, mem_slots: int):
|
||||
"""Init pofo table
|
||||
The pofo table contains two tables, opt and what, indicating values and
|
||||
operations.
|
||||
|
||||
Args:
|
||||
chain_length (int): chain length
|
||||
mem_slots (int): number of memory slots
|
||||
"""
|
||||
|
||||
self.length = chain_length
|
||||
self.mem_slots = mem_slots
|
||||
|
||||
# initializing tables
|
||||
# the first bool indicates whether the input has bar
|
||||
# opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
|
||||
# what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
|
||||
# where is_enable indicates whether we enable the gradient, is_offload indicates whether we
|
||||
# offload the input, index indicates the end of F_\empty sequence if is_enable = False
|
||||
self.opt = {
|
||||
False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
|
||||
True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
|
||||
}
|
||||
self.what = {
|
||||
False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
|
||||
True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
|
||||
}
|
||||
|
||||
def _get_value(self, state, table, default):
|
||||
i, act_size, df, db, input_has_bar = state
|
||||
if act_size + df > self.mem_slots or act_size + db > self.mem_slots:
|
||||
return default
|
||||
|
||||
try:
|
||||
return table[input_has_bar][i][act_size][(df, db)]
|
||||
except KeyError:
|
||||
print(f"state not found {state}")
|
||||
|
||||
def get_opt(self, state):
|
||||
return self._get_value(state, self.opt, INF)
|
||||
|
||||
def get_what(self, state):
|
||||
return self._get_value(state, self.what, INF)
|
||||
|
||||
def set_value(self, state, opt, what):
|
||||
i, act_size, df, db, input_has_bar = state
|
||||
self.opt[input_has_bar][i][act_size][(df, db)] = opt
|
||||
self.what[input_has_bar][i][act_size][(df, db)] = what
|
||||
|
||||
|
||||
class PofoSolver:
|
||||
"""PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
|
||||
The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs
|
||||
and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse
|
||||
rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload.
|
||||
"""
|
||||
|
||||
def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None:
|
||||
self.chain = chain
|
||||
self.length = chain.length
|
||||
self.max_memory = max_memory
|
||||
self.mem_slots = mem_slots
|
||||
self.mem_unit = max_memory / mem_slots
|
||||
self.bandwidth = bandwidth
|
||||
|
||||
self.disc_chain = copy.deepcopy(self.chain)
|
||||
self.disc_chain._discretize(self.mem_unit)
|
||||
|
||||
self.rotor_table = _compute_table(self.disc_chain, mem_slots)
|
||||
self._compute_pofo_table()
|
||||
|
||||
def _discretize(self, *values) -> Tuple:
|
||||
return tuple(math.ceil(value / self.mem_unit) for value in values)
|
||||
|
||||
def _undiscretize(self, *discrete_values) -> Tuple:
|
||||
if len(discrete_values) == 1:
|
||||
return discrete_values[0] * self.mem_unit
|
||||
else:
|
||||
return tuple(d * self.mem_unit for d in discrete_values)
|
||||
|
||||
def _mmax_all(self, idx: int):
|
||||
"""
|
||||
Calculate the maximum memory usage of Fi_all
|
||||
"""
|
||||
|
||||
return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx]
|
||||
|
||||
def _mmax_b(self, idx: int):
|
||||
"""
|
||||
Calculate the maximum memory usage of Bi
|
||||
"""
|
||||
|
||||
return self.chain.cbweight[idx +
|
||||
1] + self.chain.cweight[idx +
|
||||
1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx]
|
||||
|
||||
def _mmax_ng(self, i: int, j: int):
|
||||
"""
|
||||
Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty
|
||||
"""
|
||||
|
||||
res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j]
|
||||
if j > i:
|
||||
res += self.chain.cweight[j]
|
||||
return res
|
||||
|
||||
def _rotor_estimated_bwd(self, i, j, m, delta):
|
||||
compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j]
|
||||
comm = delta / self.bandwidth
|
||||
return (max(compute, comm) + compute + comm) / 2
|
||||
|
||||
def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
|
||||
return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)
|
||||
|
||||
def _common_values_enable(self, state: Tuple):
|
||||
|
||||
idx, act_size, df, db, input_has_bar = state
|
||||
input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
|
||||
mf = act_size + df + input_size
|
||||
mb = act_size + db + input_size
|
||||
mem_avail = self.max_memory - act_size - input_size
|
||||
f_usage = self._mmax_all(idx)
|
||||
b_usage = self._mmax_b(idx)
|
||||
|
||||
# infeasible
|
||||
if f_usage > mem_avail or b_usage > mem_avail:
|
||||
return None
|
||||
|
||||
# calculate idle time
|
||||
eps_f_beta = max(0, f_usage - self.max_memory + mf)
|
||||
eps_b_beta = max(0, b_usage - self.max_memory + mb)
|
||||
idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth
|
||||
|
||||
# calculate offload and prefetch data
|
||||
offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta
|
||||
prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta
|
||||
|
||||
# total_time
|
||||
total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time
|
||||
|
||||
return (offload_data, prefetch_data, total_time, idle_time)
|
||||
|
||||
def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False):
|
||||
|
||||
i, act_size, df, db, input_has_bar = state
|
||||
|
||||
# compute new epsilon_tmp and sum_fwds
|
||||
if iterative:
|
||||
self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds)
|
||||
self.sum_fwds += self.chain.fweight[j]
|
||||
else:
|
||||
self.epsilon_tmp = max(
|
||||
self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1))
|
||||
self.sum_fwds = sum(self.chain.fweight[i:j + 1])
|
||||
|
||||
input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]
|
||||
mf = act_size + df + input_size
|
||||
mem_avail = self.max_memory - act_size - input_size
|
||||
|
||||
# if infeasible
|
||||
if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail:
|
||||
return None
|
||||
|
||||
eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf)
|
||||
offload_data = self.sum_fwds * self.bandwidth + eps_f_beta
|
||||
|
||||
# TODO: Implement the precise backward recompute sequence mentioned in the paper
|
||||
# currently we will use an approximate way to get the backward time
|
||||
time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db)
|
||||
|
||||
prefetch_data = time_backward * self.bandwidth
|
||||
idle_time = eps_f_beta / self.bandwidth
|
||||
total_time = self.sum_fwds + idle_time + time_backward
|
||||
|
||||
return (offload_data, prefetch_data, total_time, idle_time)
|
||||
|
||||
def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple:
|
||||
"""Generate new values for next state
|
||||
|
||||
Args:
|
||||
state (Tuple): undiscretized states
|
||||
do_offload (bool): bool type indicates whether we need to do offload
|
||||
common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)
|
||||
|
||||
Returns:
|
||||
Tuple: (new_act_size, new_df, new_db)
|
||||
"""
|
||||
idx, act_size, df, db, input_has_bar = state
|
||||
offload_data, prefetch_data, *_ = common_values
|
||||
input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
|
||||
if do_offload:
|
||||
new_act_size = act_size
|
||||
new_df = max(0, df + input_size - offload_data)
|
||||
new_db = max(0, db - prefetch_data) + input_size
|
||||
else:
|
||||
new_act_size = act_size + input_size
|
||||
new_df = max(0, df - offload_data)
|
||||
new_db = max(0, db - prefetch_data)
|
||||
|
||||
return (new_act_size, new_df, new_db)
|
||||
|
||||
def _compute_pofo_table(self):
|
||||
self.table = PofoTable(self.length, self.mem_slots)
|
||||
|
||||
# initializing the loss
|
||||
for act_size in range(self.mem_slots + 1):
|
||||
for df in range(self.mem_slots - act_size + 1):
|
||||
for db in range(self.mem_slots - act_size + 1):
|
||||
# undiscretize for idle time calculation
|
||||
origin_values = self._undiscretize(act_size, df, db)
|
||||
|
||||
for input_has_bar in (False, True):
|
||||
disc_state = (self.length, act_size, df, db, input_has_bar)
|
||||
state = (self.length, *origin_values, input_has_bar)
|
||||
common_values = self._common_values_enable(state)
|
||||
|
||||
# if no feasible choice
|
||||
if common_values is None:
|
||||
self.table.set_value(disc_state, INF, None)
|
||||
continue
|
||||
|
||||
# if there is feasible choice
|
||||
new_act_size, new_df, new_db = self._new_values(state, False, common_values)
|
||||
eps_g = (new_df + new_db) / self.bandwidth
|
||||
total_time = common_values[2] + eps_g
|
||||
self.table.set_value(disc_state, total_time, (True, False))
|
||||
|
||||
# main loop
|
||||
for i in reversed(range(self.length)):
|
||||
for act_size in range(self.mem_slots + 1):
|
||||
for df in range(self.mem_slots - act_size + 1):
|
||||
for db in range(self.mem_slots - act_size + 1):
|
||||
# undiscretize for idle time calculation
|
||||
origin_values = self._undiscretize(act_size, df, db)
|
||||
|
||||
for input_has_bar in (False, True):
|
||||
best_result = INF
|
||||
best_choice = None
|
||||
disc_state = (i, act_size, df, db, input_has_bar)
|
||||
state = (i, *origin_values, input_has_bar)
|
||||
|
||||
# case 1: start with F_all
|
||||
vals_enable = self._common_values_enable(state)
|
||||
if vals_enable is not None:
|
||||
for do_offload in (True, False):
|
||||
new_state = self._new_values(state, do_offload, vals_enable)
|
||||
new_state = (i + 1, *self._discretize(*new_state), True)
|
||||
total_time = vals_enable[2]
|
||||
results_all = self.table.get_opt(new_state) + total_time
|
||||
if results_all < best_result:
|
||||
best_result = results_all
|
||||
best_choice = (True, do_offload)
|
||||
|
||||
# case 2: start with F_ck
|
||||
self.sum_fwds = 0
|
||||
self.epsilon_tmp = 0
|
||||
for j in range(i, self.length):
|
||||
vals_nograd = self._common_values_nograd(state, j, True)
|
||||
|
||||
# if infeasible
|
||||
if vals_nograd is None:
|
||||
continue
|
||||
|
||||
for do_offload in (True, False):
|
||||
new_state = self._new_values(state, do_offload, vals_nograd)
|
||||
new_state = (j + 1, *self._discretize(*new_state), False)
|
||||
total_time = vals_nograd[2]
|
||||
result_nograd = total_time + self.table.get_opt(new_state)
|
||||
if result_nograd < best_result:
|
||||
best_result = result_nograd
|
||||
best_choice = (False, do_offload, j)
|
||||
|
||||
self.table.set_value(disc_state, best_result, best_choice)
|
||||
|
||||
def pofo_rec(self, disc_state):
|
||||
i, act_size, df, db, input_has_bar = disc_state
|
||||
result = Sequence(Function("pofo", *disc_state))
|
||||
what = self.table.get_what(disc_state)
|
||||
state = self._undiscretize(act_size, df, db)
|
||||
state = (i, *state, input_has_bar)
|
||||
i, act_size, df, db, input_has_bar = state
|
||||
|
||||
if what is None:
|
||||
return None
|
||||
|
||||
# if loss
|
||||
if i == self.length:
|
||||
result.insert(Loss())
|
||||
return result
|
||||
|
||||
if what[0]:
|
||||
do_offload = what[1]
|
||||
values = self._common_values_enable(state)
|
||||
new_state = self._discretize(*self._new_values(state, do_offload, values))
|
||||
new_state = (i + 1, *new_state, True)
|
||||
if do_offload:
|
||||
result.insert(Offload(i, input_has_bar))
|
||||
result.insert(ForwardEnable(i))
|
||||
result.insert_sequence(self.pofo_rec(new_state))
|
||||
if do_offload:
|
||||
result.insert(Prefetch(i, input_has_bar))
|
||||
result.insert(Backward(i))
|
||||
|
||||
else:
|
||||
_, do_offload, j = what
|
||||
values = self._common_values_nograd(state, j)
|
||||
new_state = self._discretize(*self._new_values(state, do_offload, values))
|
||||
new_state = (j + 1, *new_state, False)
|
||||
if do_offload:
|
||||
result.insert(Offload(i, input_has_bar))
|
||||
result.insert(ForwardCheck(i))
|
||||
for k in range(i + 1, j + 1):
|
||||
result.insert(ForwardNograd(k))
|
||||
result.insert_sequence(self.pofo_rec(new_state))
|
||||
if do_offload:
|
||||
result.insert(Prefetch(i, input_has_bar))
|
||||
m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i])
|
||||
|
||||
#TODO: Implement the precise backward recompute sequence mentioned in the paper
|
||||
result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
||||
ckpt_idx = 0
|
||||
in_ckpt = False
|
||||
ckpt_region = []
|
||||
|
||||
# forward annotation
|
||||
for op in fwd_list:
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
else:
|
||||
if isinstance(op, ForwardCheck):
|
||||
in_ckpt = True
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# annotate the backward if there is any nested activation checkpoint
|
||||
in_recompute = False
|
||||
for op in bwd_list:
|
||||
if in_recompute:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
else:
|
||||
if not isinstance(op, Backward):
|
||||
in_recompute = True
|
||||
ckpt_idx = 0
|
||||
ckpt_region = []
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# postprocess, make sure every activation checkpoint label in the
|
||||
# same activation checkpoint region (level = 0) has the same length
|
||||
op_list = []
|
||||
for node in node_list:
|
||||
op_list += node
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
|
||||
|
||||
# annotate the offload
|
||||
offload_idx = 0
|
||||
for idx, op in enumerate(fwd_list):
|
||||
if isinstance(op, Offload):
|
||||
# corner case: offload input
|
||||
if op.index == 0:
|
||||
if isinstance(fwd_list[idx + 1], ForwardCheck):
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", True)
|
||||
else:
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", (offload_idx, True, False))
|
||||
offload_idx += 1
|
||||
|
||||
else:
|
||||
if op.has_bar:
|
||||
# annotate previous node
|
||||
if hasattr(node_list[op.index - 1][0], "activation_offload"):
|
||||
for n in node_list[op.index - 1]:
|
||||
n.activation_offload[-1] = True
|
||||
else:
|
||||
for n in node_list[op.index - 1]:
|
||||
setattr(n, "activation_offload", [offload_idx, False, True])
|
||||
|
||||
offload_idx += 1
|
||||
|
||||
# annotate this node
|
||||
if isinstance(fwd_list[idx + 1], ForwardCheck):
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", True)
|
||||
else:
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", [offload_idx, True, False])
|
||||
|
||||
offload_idx += 1
|
||||
|
||||
|
||||
def solver_pofo(gm: ColoGraphModule,
|
||||
data,
|
||||
bandwidth,
|
||||
flops,
|
||||
mem_limit: int,
|
||||
mem_slots: int = 50,
|
||||
cnode: List[str] = None,
|
||||
eps: float = 0.0) -> ColoGraphModule:
|
||||
"""Solver that combine offload and activation checkpoint
|
||||
Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
|
||||
|
||||
Args:
|
||||
gm (ColoGraphModule): ColoGraphModule derived from tracer
|
||||
data: input of the model
|
||||
bandwidth: offload bandwidth, unit Byte/s
|
||||
flops: FLOPS of device, unit FLOPs/s
|
||||
mem_limit (int): memory limit, unit Byte
|
||||
mem_slots (int, optional): number of memory slots. Defaults to 500.
|
||||
cnode (List[str], optional): common node for linearize. Defaults to None.
|
||||
eps (float, optional): epsilon for memory decay. Defaults to 0.02.
|
||||
|
||||
Returns:
|
||||
ColoGraphModule: annotated graph module
|
||||
"""
|
||||
|
||||
node_list = linearize(gm, cnode)
|
||||
mem_limit -= parameter_size(gm)
|
||||
|
||||
# prepare data
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
chain = _normalize_flops(chain, flops)
|
||||
# currently we view loss as an op without expense
|
||||
chain.cbweight.append(0)
|
||||
chain.cweight.append(0)
|
||||
chain.fwd_mem_tmp.append(0)
|
||||
chain.bwd_mem_tmp.append(0)
|
||||
chain.fweight.append(0)
|
||||
chain.bweight.append(0)
|
||||
|
||||
solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots)
|
||||
first_state = (0, 0, 0, 0, False)
|
||||
sequence = solver.pofo_rec(first_state)
|
||||
if sequence == None:
|
||||
raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
|
||||
|
||||
_annotate_from_pofo_sequence(sequence, node_list)
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
return gm
|
|
@ -1,436 +0,0 @@
|
|||
import math
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from torch.fx import Node
|
||||
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .linearize import linearize
|
||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
|
||||
|
||||
# global vairable to indicate whether the solver is failed
|
||||
SOLVER_FAILED = False
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
# https://gitlab.inria.fr/hiepacs/rotor
|
||||
# paper link: https://hal.inria.fr/hal-02352969
|
||||
def _compute_table(chain: Chain, mmax) -> Tuple:
|
||||
"""Returns the optimal table: a tuple containing:
|
||||
Opt[m][lmin][lmax] with lmin = 0...chain.length
|
||||
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
|
||||
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
|
||||
(False, j) if the optimal choice is a leaf checkpoint of length j
|
||||
The computation uses dynamic programming"""
|
||||
|
||||
fw = chain.fweight + [0] ## forward time
|
||||
bw = chain.bweight ## backward time, not used
|
||||
cw = chain.cweight + [0] ## size of x (and of y)
|
||||
cbw = chain.cbweight + [0] ## size of xbar
|
||||
fwd_mem_tmp = chain.fwd_mem_tmp + [0]
|
||||
bwd_mem_tmp = chain.bwd_mem_tmp + [0]
|
||||
|
||||
# Build table
|
||||
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
|
||||
what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
|
||||
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
|
||||
|
||||
# Initialize borders of the tables for lmax-lmin = 0
|
||||
for m in range(mmax + 1):
|
||||
for i in range(chain.length + 1):
|
||||
#lmax-lmin = 0
|
||||
limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
|
||||
if m >= limit: ## Equation (1)
|
||||
opt[m][i][i] = fw[i] + bw[i]
|
||||
else:
|
||||
opt[m][i][i] = float("inf")
|
||||
|
||||
# Compute everything
|
||||
for m in range(mmax + 1):
|
||||
for d in range(1, chain.length + 1):
|
||||
for i in range(chain.length + 1 - d):
|
||||
# for idx in range(i+1, chain.length + 1):
|
||||
idx = i + d
|
||||
mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
|
||||
if idx > i + 1:
|
||||
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
|
||||
if m < mmin:
|
||||
opt[m][i][idx] = float("inf")
|
||||
else:
|
||||
leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
|
||||
for j in range(i + 1, idx + 1)
|
||||
if m >= cw[j]]
|
||||
if leaf_checkpoints:
|
||||
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
|
||||
else:
|
||||
best_leaf = None
|
||||
if m >= cbw[i + 1]:
|
||||
chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
|
||||
else:
|
||||
chain_checkpoint = float("inf")
|
||||
if best_leaf and best_leaf[1] <= chain_checkpoint:
|
||||
opt[m][i][idx] = best_leaf[1]
|
||||
what[m][i][idx] = (False, best_leaf[0])
|
||||
else:
|
||||
opt[m][i][idx] = chain_checkpoint
|
||||
what[m][i][idx] = (True,)
|
||||
return (opt, what)
|
||||
|
||||
|
||||
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
||||
""" chain : the class describing the AC graph
|
||||
lmin : index of the first forward to execute
|
||||
lmax : upper bound index of the last forward to execute (not included)
|
||||
cmem : number of available memory slots
|
||||
Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
|
||||
if cmem <= 0:
|
||||
raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem))
|
||||
opt, what = opt_table
|
||||
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
|
||||
if opt[cmem][lmin][lmax] == float("inf"):
|
||||
# using logger to annonce that the solver is failed
|
||||
logger = get_dist_logger()
|
||||
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
|
||||
lmax=lmax,
|
||||
cmem=cmem))
|
||||
|
||||
# set global indicater SOLVER_FAILED to True
|
||||
global SOLVER_FAILED
|
||||
SOLVER_FAILED = True
|
||||
return sequence
|
||||
|
||||
if lmin == lmax:
|
||||
if lmin == chain.length:
|
||||
sequence.insert(Loss())
|
||||
else:
|
||||
sequence.insert(ForwardEnable(lmin))
|
||||
sequence.insert(Backward(lmin))
|
||||
return sequence
|
||||
|
||||
if what[cmem][lmin][lmax][0]:
|
||||
sequence.insert(ForwardEnable(lmin))
|
||||
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
|
||||
sequence.insert(Backward(lmin))
|
||||
else:
|
||||
j = what[cmem][lmin][lmax][1]
|
||||
sequence.insert(ForwardCheck(lmin))
|
||||
for k in range(lmin + 1, j):
|
||||
sequence.insert(ForwardNograd(k))
|
||||
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
|
||||
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
|
||||
return sequence
|
||||
|
||||
|
||||
def _fwd_xbar(node: List[Node]) -> int:
|
||||
"""Get the forward xbar of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: xbar size, unit Byte
|
||||
"""
|
||||
|
||||
xbar = 0
|
||||
for n in node:
|
||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
||||
return xbar
|
||||
|
||||
|
||||
def _fwd_time(node: List[Node]) -> int:
|
||||
"""Get the foward time of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: foward time, extimated by flops count
|
||||
"""
|
||||
|
||||
fwd_time = 0
|
||||
for n in node:
|
||||
# minimum flop count is needed
|
||||
fwd_time += max(n.meta['fwd_flop'], 1)
|
||||
return fwd_time
|
||||
|
||||
|
||||
def _bwd_time(node: List[Node]) -> int:
|
||||
"""Get the backward time of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: backward time, extimated by flops count
|
||||
"""
|
||||
|
||||
bwd_time = 0
|
||||
for n in node:
|
||||
# minimum flop count is needed
|
||||
bwd_time += max(n.meta['bwd_flop'], 1)
|
||||
return bwd_time
|
||||
|
||||
|
||||
def _get_fwd_mem_tmp(node: List[Node]) -> int:
|
||||
"""Get the forward temp memory of a node
|
||||
This could be done by subtracting the saved activation from all output of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: forward temp memory, unit Byte
|
||||
"""
|
||||
n = node[-1]
|
||||
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
|
||||
|
||||
|
||||
def _get_bwd_mem_tmp(node: List[Node]) -> int:
|
||||
"""Get the backward temp memory of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
int: backward temp memory, unit Byte
|
||||
"""
|
||||
|
||||
def _get_deps_size():
|
||||
deps_size = 0
|
||||
for k, v in deps.items():
|
||||
k: Node
|
||||
if v > 0:
|
||||
deps_size += k.meta['bwd_mem_out']
|
||||
if v == float('-inf'):
|
||||
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
|
||||
|
||||
return deps_size
|
||||
|
||||
bwd_mem_tmp = 0
|
||||
deps = {}
|
||||
|
||||
for n in reversed(node):
|
||||
deps[n] = len(n.all_input_nodes)
|
||||
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
|
||||
|
||||
for child in n.users:
|
||||
if child in deps:
|
||||
deps[child] -= 1
|
||||
if deps[child] <= 0:
|
||||
deps[child] = float('-inf') # free
|
||||
|
||||
return bwd_mem_tmp
|
||||
|
||||
|
||||
def _construct_chain(node_list: List[List[Node]], input) -> Chain:
|
||||
|
||||
fwd_time = []
|
||||
bwd_time = []
|
||||
xbar_sizes = [activation_size(input)]
|
||||
x_sizes = [activation_size(input)]
|
||||
tmp_fwd = []
|
||||
tmp_bwd = []
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
fwd_time.append(_fwd_time(node))
|
||||
bwd_time.append(_bwd_time(node))
|
||||
x_sizes.append(calculate_fwd_out(node[-1]))
|
||||
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
|
||||
tmp_fwd.append(_get_fwd_mem_tmp(node))
|
||||
tmp_bwd.append(_get_bwd_mem_tmp(node))
|
||||
|
||||
bwd_time.append(0)
|
||||
|
||||
# currently we view loss backward temp as zero
|
||||
tmp_bwd.append(0)
|
||||
|
||||
return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
|
||||
|
||||
|
||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
||||
ckpt_idx = 0
|
||||
in_ckpt = False
|
||||
ckpt_region = []
|
||||
|
||||
# forward annotation
|
||||
for idx, op in enumerate(fwd_list, 0):
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(idx)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [idx]
|
||||
|
||||
else:
|
||||
if isinstance(op, ForwardCheck):
|
||||
in_ckpt = True
|
||||
ckpt_region.append(idx)
|
||||
|
||||
# annotate the backward if there is any nested activation checkpoint
|
||||
in_recompute = False
|
||||
for op in bwd_list:
|
||||
if in_recompute:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
else:
|
||||
if not isinstance(op, Backward):
|
||||
in_recompute = True
|
||||
ckpt_idx = 0
|
||||
ckpt_region = []
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# postprocess, make sure every activation checkpoint label in the
|
||||
# same activation checkpoint region (level = 0) has the same length
|
||||
op_list = []
|
||||
for node in node_list:
|
||||
op_list += node
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
|
||||
|
||||
|
||||
def solver_rotor(gm: ColoGraphModule,
|
||||
data,
|
||||
mem_limit: int,
|
||||
mem_slots: int = 500,
|
||||
cnode: List[str] = None,
|
||||
eps: float = 0.0,
|
||||
force_python: bool = False) -> ColoGraphModule:
|
||||
"""solver that automatically find activation checkpoint in rotor's manner
|
||||
|
||||
Args:
|
||||
gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
|
||||
data (torch.Tensor): input data.
|
||||
mem_limit (int): memory budget in Byte.
|
||||
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
|
||||
cnode (List[Node], optional): common node list for linearize. Defaults to None.
|
||||
eps (float): epsilon for memory decay. Defaults to 0.0
|
||||
force_python (bool): force to use python version of dynamic programs
|
||||
|
||||
Returns:
|
||||
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
|
||||
"""
|
||||
|
||||
# try to import C version solver if force_python is not set
|
||||
logger = get_dist_logger()
|
||||
if not force_python:
|
||||
try:
|
||||
from .dynamic_programs_C_version import persistent_compute_table
|
||||
CVERSION = True
|
||||
|
||||
# build module if module not found
|
||||
except ModuleNotFoundError:
|
||||
import os
|
||||
import subprocess
|
||||
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
result = subprocess.Popen(
|
||||
[
|
||||
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
|
||||
f"--build-lib={this_dir}"
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
if result.wait() == 0:
|
||||
logger.info("dynamic_programs_C_version has been built!", ranks=[0])
|
||||
from .dynamic_programs_C_version import persistent_compute_table
|
||||
CVERSION = True
|
||||
else:
|
||||
logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
|
||||
CVERSION = False
|
||||
else:
|
||||
CVERSION = False
|
||||
|
||||
# check if metainfoprop is done
|
||||
if any(len(node.meta) == 0 for node in gm.graph.nodes):
|
||||
raise RuntimeError(
|
||||
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")
|
||||
|
||||
# linearize the graph
|
||||
node_list = linearize(gm, cnode)
|
||||
|
||||
# construct chain
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
chain._discretize(mem_unit)
|
||||
|
||||
# use C version if possible
|
||||
if CVERSION and not force_python:
|
||||
logger.info("Using C version rotor solver!", ranks=[0])
|
||||
opt_table = persistent_compute_table(chain, mem_slots)
|
||||
else:
|
||||
opt_table = _compute_table(chain, mem_slots)
|
||||
logger.info("Using python version rotor solver!", ranks=[0])
|
||||
|
||||
# found sequence
|
||||
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
||||
|
||||
# if solver failed, we don't need to annotate the graph
|
||||
if not SOLVER_FAILED:
|
||||
_annotate_from_sequence(sequence, node_list)
|
||||
|
||||
# set __sequence__ attribute to GraphModule
|
||||
if SOLVER_FAILED:
|
||||
setattr(gm, "__sequence__", None)
|
||||
else:
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
|
||||
# set __opttable__ attribute to GraphModule
|
||||
setattr(gm, "__opttable__", opt_table[0])
|
||||
gm.recompile()
|
||||
return gm
|
|
@ -1,516 +0,0 @@
|
|||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
long* PySequenceToLongArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
long* result = (long*)calloc(len + 1, sizeof(long));
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_GetItem(pylist, i);
|
||||
result[i] = PyLong_AsLong(item);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
result[len] = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
double* PySequenceToDoubleArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
double* result = (double*)calloc(len + 1, sizeof(double));
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_GetItem(pylist, i);
|
||||
result[i] = PyFloat_AsDouble(item);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
result[len] = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
long* getLongArray(PyObject* container, const char* attributeName) {
|
||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
||||
long* result = PySequenceToLongArray(sequence);
|
||||
Py_DECREF(sequence);
|
||||
return result;
|
||||
}
|
||||
|
||||
double* getDoubleArray(PyObject* container, const char* attributeName) {
|
||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
||||
double* result = PySequenceToDoubleArray(sequence);
|
||||
Py_DECREF(sequence);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
||||
PyObject* chain_param;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
||||
|
||||
double* fw = getDoubleArray(chain_param, "fweight");
|
||||
if (!fw) return NULL;
|
||||
|
||||
double* bw = getDoubleArray(chain_param, "bweight");
|
||||
if (!bw) return NULL;
|
||||
|
||||
long* cw = getLongArray(chain_param, "cweight");
|
||||
if (!cw) return NULL;
|
||||
|
||||
long* cbw = getLongArray(chain_param, "cbweight");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
||||
if (!chain_length_param) return NULL;
|
||||
long chain_length = PyLong_AsLong(chain_length_param);
|
||||
Py_DECREF(chain_length_param);
|
||||
|
||||
// TODO: Can be optimized by only allocating memory for l >= i
|
||||
// TODO: float / int instead of double / long ?
|
||||
#define OPT(m, i, l) \
|
||||
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
|
||||
(i) * (chain_length + 1) + (l)]
|
||||
double* opt = (double*)calloc(
|
||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
|
||||
|
||||
#define WHAT(m, i, l) \
|
||||
what[(m) * (chain_length + 1) * (chain_length + 1) + \
|
||||
(i) * (chain_length + 1) + (l)]
|
||||
long* what = (long*)calloc(
|
||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i)
|
||||
// TODO: Can be optimized to remove the IF by reordering loops
|
||||
if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
|
||||
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
|
||||
OPT(m, i, i) = fw[i] + bw[i];
|
||||
else
|
||||
OPT(m, i, i) = INFINITY;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long d = 1; d <= chain_length; ++d) {
|
||||
for (long i = 0; i <= chain_length - d; ++i) {
|
||||
long idx = i + d;
|
||||
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
|
||||
if (idx > i + 1) {
|
||||
long maxCostFWD = 0;
|
||||
for (long j = i + 1; j < idx; j++) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
|
||||
}
|
||||
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
|
||||
}
|
||||
if ((m >= mmin)) {
|
||||
long bestLeaf = -1;
|
||||
double sumFw = 0;
|
||||
double bestLeafCost = INFINITY;
|
||||
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
|
||||
/// i+1
|
||||
for (long j = i + 1; j <= idx; ++j) {
|
||||
sumFw += fw[j - 1];
|
||||
if (m >= cw[j]) {
|
||||
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
|
||||
if (cost < bestLeafCost) {
|
||||
bestLeafCost = cost;
|
||||
bestLeaf = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
double chainCost = INFINITY;
|
||||
if (m >= cbw[i + 1])
|
||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
|
||||
if (bestLeafCost <= chainCost) {
|
||||
OPT(m, i, idx) = bestLeafCost;
|
||||
WHAT(m, i, idx) = bestLeaf;
|
||||
} else {
|
||||
OPT(m, i, idx) = chainCost;
|
||||
WHAT(m, i, idx) = -1;
|
||||
}
|
||||
} else
|
||||
OPT(m, i, idx) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
free(fw);
|
||||
free(bw);
|
||||
free(cw);
|
||||
free(cbw);
|
||||
free(fwd_tmp);
|
||||
free(bwd_tmp);
|
||||
|
||||
PyObject* res_opt = PyList_New(mmax + 1);
|
||||
PyObject* res_what = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* res_opt_m = PyList_New(chain_length + 1);
|
||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
||||
PyObject* res_what_m = PyList_New(chain_length + 1);
|
||||
PyList_SET_ITEM(res_what, m, res_what_m);
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
PyObject* res_opt_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
|
||||
PyObject* res_what_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(res_what_m, i, res_what_m_i);
|
||||
for (long l = i; l <= chain_length; ++l) {
|
||||
PyObject* res_l = PyLong_FromLong(l);
|
||||
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
|
||||
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
|
||||
Py_DECREF(res_opt_m_i_l);
|
||||
PyObject* res_what_m_i_l;
|
||||
long what_m_i_l = WHAT(m, i, l);
|
||||
if (what_m_i_l < 0)
|
||||
res_what_m_i_l = Py_BuildValue("(O)", Py_True);
|
||||
else
|
||||
res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
|
||||
PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
|
||||
Py_DECREF(res_what_m_i_l);
|
||||
Py_DECREF(res_l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(opt);
|
||||
free(what);
|
||||
|
||||
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
|
||||
Py_DECREF(res_opt);
|
||||
Py_DECREF(res_what);
|
||||
return result;
|
||||
}
|
||||
|
||||
// long i = L - s, j = t - s, k = l - t
|
||||
inline long floating_index_in_array(long m_factor, long m, long i, long j,
|
||||
long k) {
|
||||
return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
|
||||
(j * (j - 1)) / 2 + k;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
long sp;
|
||||
long r;
|
||||
long tp;
|
||||
} index_t;
|
||||
|
||||
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
|
||||
PyObject* chain_param;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
||||
|
||||
double* fw = getDoubleArray(chain_param, "fweigth");
|
||||
if (!fw) return NULL;
|
||||
|
||||
double* bw = getDoubleArray(chain_param, "bweigth");
|
||||
if (!bw) return NULL;
|
||||
|
||||
long* cw = getLongArray(chain_param, "cweigth");
|
||||
if (!cw) return NULL;
|
||||
|
||||
long* cbw = getLongArray(chain_param, "cbweigth");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
|
||||
if (!fwd_tmp) return NULL;
|
||||
|
||||
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
|
||||
if (!bwd_tmp) return NULL;
|
||||
|
||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
||||
if (!chain_length_param) return NULL;
|
||||
long chain_length = PyLong_AsLong(chain_length_param);
|
||||
Py_DECREF(chain_length_param);
|
||||
|
||||
const long m_factor =
|
||||
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
|
||||
|
||||
// Defined for 0 <= s <= t <= l <= chain_length, for all m
|
||||
#undef OPT
|
||||
#define OPT(m, s, t, l) \
|
||||
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
|
||||
(l) - (t))]
|
||||
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
|
||||
|
||||
#undef WHAT
|
||||
#define WHAT(m, s, t, l) \
|
||||
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
|
||||
(l) - (t))]
|
||||
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
|
||||
|
||||
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
|
||||
double total = 0;
|
||||
for (long i = 0; i < chain_length; ++i) {
|
||||
partialSumsFW[i] = total;
|
||||
total += fw[i];
|
||||
}
|
||||
partialSumsFW[chain_length] = total;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
// TODO: Can be optimized to remove the IF by reordering loops
|
||||
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
|
||||
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
|
||||
OPT(m, i, i, i) = fw[i] + bw[i];
|
||||
else
|
||||
OPT(m, i, i, i) = INFINITY;
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long d = 1; d <= chain_length; ++d) { // d = l - s
|
||||
for (long s = 0; s <= chain_length - d; ++s) {
|
||||
long l = s + d;
|
||||
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
|
||||
long memNullSecond = 0;
|
||||
for (long j = s + 1; j < l; ++j) {
|
||||
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
|
||||
if (val > memNullSecond) memNullSecond = val;
|
||||
}
|
||||
for (long t = s; t <= l; ++t) {
|
||||
double chainCost = INFINITY;
|
||||
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
|
||||
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
|
||||
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
|
||||
}
|
||||
double bestLeafCost = INFINITY;
|
||||
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
|
||||
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
|
||||
for (long r = s; r <= t; ++r)
|
||||
if (cw[s] <= cw[r])
|
||||
for (long tp = t + 1; tp <= l; ++tp)
|
||||
for (long sp = r + 1; sp <= tp; ++sp) {
|
||||
long mp = m - cw[r] + cw[s];
|
||||
assert(mp >= 0);
|
||||
if (mp >= cw[sp]) {
|
||||
double value = partialSumsFW[sp] - partialSumsFW[s] +
|
||||
OPT(mp - cw[sp], sp, tp, l) +
|
||||
OPT(mp, r, t, tp - 1);
|
||||
if (value < bestLeafCost) {
|
||||
bestLeafCost = value;
|
||||
bestLeaf.sp = sp;
|
||||
bestLeaf.r = r;
|
||||
bestLeaf.tp = tp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
|
||||
OPT(m, s, t, l) = bestLeafCost;
|
||||
WHAT(m, s, t, l).sp = bestLeaf.sp;
|
||||
WHAT(m, s, t, l).r = bestLeaf.r;
|
||||
WHAT(m, s, t, l).tp = bestLeaf.tp;
|
||||
} else {
|
||||
OPT(m, s, t, l) = chainCost;
|
||||
WHAT(m, s, t, l).sp = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(fw);
|
||||
free(bw);
|
||||
free(cw);
|
||||
free(cbw);
|
||||
free(fwd_tmp);
|
||||
free(bwd_tmp);
|
||||
|
||||
PyObject* res_opt = PyList_New(mmax + 1);
|
||||
PyObject* res_what = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* res_opt_m = PyDict_New();
|
||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
||||
PyObject* res_what_m = PyDict_New();
|
||||
PyList_SET_ITEM(res_what, m, res_what_m);
|
||||
for (long s = 0; s <= chain_length; ++s)
|
||||
for (long t = s; t <= chain_length; ++t)
|
||||
for (long l = t; l <= chain_length; ++l) {
|
||||
PyObject* key = Py_BuildValue("(lll)", s, t, l);
|
||||
PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
|
||||
PyDict_SetItem(res_opt_m, key, value_opt);
|
||||
PyObject* value_what = true_tuple;
|
||||
index_t* idx_what = &WHAT(m, s, t, l);
|
||||
if (idx_what->sp >= 0)
|
||||
value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
|
||||
idx_what->r, idx_what->tp);
|
||||
PyDict_SetItem(res_what_m, key, value_what);
|
||||
if (value_what != true_tuple) Py_DECREF(value_what);
|
||||
Py_DECREF(key);
|
||||
Py_DECREF(value_opt);
|
||||
}
|
||||
}
|
||||
|
||||
Py_DECREF(true_tuple);
|
||||
|
||||
free(opt);
|
||||
free(what);
|
||||
|
||||
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
|
||||
Py_DECREF(res_opt);
|
||||
Py_DECREF(res_what);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
|
||||
PyObject* args) {
|
||||
PyObject* chain_param;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
||||
|
||||
double* fw = getDoubleArray(chain_param, "fweigth");
|
||||
if (!fw) return NULL;
|
||||
|
||||
double* bw = getDoubleArray(chain_param, "bweigth");
|
||||
if (!bw) return NULL;
|
||||
|
||||
long* cw = getLongArray(chain_param, "cweigth");
|
||||
if (!cw) return NULL;
|
||||
|
||||
long* cbw = getLongArray(chain_param, "cbweigth");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
||||
if (!chain_length_param) return NULL;
|
||||
long chain_length = PyLong_AsLong(chain_length_param);
|
||||
Py_DECREF(chain_length_param);
|
||||
|
||||
// TODO: Can be optimized by only allocating memory for l >= i
|
||||
// TODO: float / int instead of double / long ?
|
||||
#undef OPT
|
||||
#define OPT(m, i, l) \
|
||||
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
|
||||
(i) * (chain_length + 1) + (l)]
|
||||
double* opt = (double*)calloc(
|
||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
|
||||
|
||||
// Compute partial sums
|
||||
double* sumfw = (double*)calloc(chain_length, sizeof(double));
|
||||
double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
|
||||
double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
|
||||
|
||||
double total = 0;
|
||||
for (long i = 0; i < chain_length; ++i) {
|
||||
total += fw[i];
|
||||
sumfw[i] = total;
|
||||
}
|
||||
|
||||
total = 0;
|
||||
for (long i = 0; i < chain_length + 1; ++i) {
|
||||
total += bw[i];
|
||||
sumbw[i] = total;
|
||||
}
|
||||
|
||||
total = 0;
|
||||
for (long i = 0; i < chain_length; ++i) {
|
||||
total += sumfw[i];
|
||||
sumsumfw[i] = total;
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
// TODO: Can be optimized to remove the IF by reordering loops
|
||||
if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
|
||||
OPT(m, i, i) = bw[i];
|
||||
else
|
||||
OPT(m, i, i) = INFINITY;
|
||||
|
||||
if (i < chain_length) {
|
||||
long maxC = fmaxl(cw[i], cw[i + 1]);
|
||||
long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
|
||||
if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
|
||||
OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
|
||||
else
|
||||
OPT(m, i, i + 1) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i + 2 <= chain_length; ++i) {
|
||||
long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
|
||||
long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
|
||||
long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
|
||||
for (long l = i + 2; l <= chain_length; ++l) {
|
||||
maxCW_il = fmax(maxCW_il, cw[l + 1]);
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
|
||||
long mmin = fmaxl(mminCst, maxCostFWD);
|
||||
if ((m >= mmin)) {
|
||||
double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
|
||||
noCheckpointCost +=
|
||||
sumsumfw[l - 1] -
|
||||
(i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
|
||||
|
||||
double valueCost = INFINITY;
|
||||
if (m >= cw[i]) {
|
||||
double sumFwds = 0;
|
||||
for (long j = i + 1; j < l; ++j) {
|
||||
sumFwds += fw[j - 1];
|
||||
valueCost = fmin(
|
||||
valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
|
||||
}
|
||||
}
|
||||
OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
|
||||
} else
|
||||
OPT(m, i, l) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
free(sumfw);
|
||||
free(sumbw);
|
||||
free(sumsumfw);
|
||||
free(fw);
|
||||
free(bw);
|
||||
free(cw);
|
||||
free(cbw);
|
||||
|
||||
PyObject* res_opt = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* res_opt_m = PyList_New(chain_length + 1);
|
||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
PyObject* res_opt_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
|
||||
for (long l = i; l <= chain_length; ++l) {
|
||||
PyObject* res_l = PyLong_FromLong(l - i);
|
||||
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
|
||||
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
|
||||
Py_DECREF(res_opt_m_i_l);
|
||||
Py_DECREF(res_l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(opt);
|
||||
|
||||
return res_opt;
|
||||
}
|
||||
|
||||
static PyMethodDef dynamic_programs_methods[] = {
|
||||
{"persistent_compute_table", persistent_compute_table, METH_VARARGS,
|
||||
"Compute the optimal table with the persistent algorithm."},
|
||||
{"floating_compute_table", floating_compute_table, METH_VARARGS,
|
||||
"Compute the optimal table with the floating algorithm."},
|
||||
{"griewank_heterogeneous_compute_table",
|
||||
griewank_heterogeneous_compute_table, METH_VARARGS,
|
||||
"Compute the optimal table for the Griewank Heterogeneous Model."},
|
||||
{NULL, NULL, 0, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
static struct PyModuleDef dynamic_programs_module = {
|
||||
PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
|
||||
NULL, /* module documentation, may be NULL */
|
||||
-1, /* size of per-interpreter state of the module,
|
||||
or -1 if the module keeps state in global variables. */
|
||||
dynamic_programs_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
|
||||
return PyModule_Create(&dynamic_programs_module);
|
||||
}
|
|
@ -1,94 +0,0 @@
|
|||
from typing import List, Any
|
||||
from torch.fx import GraphModule, Node
|
||||
from colossalai.fx.profiler import is_inplace
|
||||
|
||||
# Common nodes are type of nodes that could be seen as attributes and remain
|
||||
# unchanged throughout the whole model, it will be used several times by
|
||||
# different blocks of model, so that it is hard for us to linearize the graph
|
||||
# when we encounter those kinds of nodes. We let users to annotate some of the
|
||||
# input as common node, such as attention mask, and the followings are some of
|
||||
# the ops that could actually be seen as common nodes. With our common node prop,
|
||||
# we could find some of the "real" common nodes (e.g. the real attention mask
|
||||
# used in BERT and GPT), the rule is simple, for node who's parents are all common
|
||||
# nodes or it's op belongs to the following operations, we view this node as a
|
||||
# newly born common node.
|
||||
# List of target name that could be seen as common node
|
||||
COPS = ["getattr", "getitem", "size"]
|
||||
|
||||
|
||||
def _is_cop(target: Any) -> bool:
|
||||
"""Check if an op could be seen as common node
|
||||
|
||||
Args:
|
||||
target (Any): node target
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
if isinstance(target, str):
|
||||
return target in COPS
|
||||
else:
|
||||
return target.__name__ in COPS
|
||||
|
||||
|
||||
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
|
||||
"""Linearizing the graph
|
||||
|
||||
Args:
|
||||
gm (GraphModule): GraphModule derived by tracing
|
||||
cnode (List[str], optional): common node List, should be the subset of input. Default to None.
|
||||
|
||||
Returns:
|
||||
List[List[Node]]: List of list, each inside list of Node presents
|
||||
the actual 'node' in linearized manner.
|
||||
|
||||
Remarks:
|
||||
We merge the inplace ops into the previous node.
|
||||
"""
|
||||
|
||||
def _is_sink() -> bool:
|
||||
"""Check if we can free all dependencies
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
|
||||
|
||||
# make sure that item in cnode is valid
|
||||
if cnode:
|
||||
for name in cnode:
|
||||
try:
|
||||
assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
|
||||
f"common node {name} is not an input of the model"
|
||||
except StopIteration:
|
||||
raise ValueError(f"common node name {name} not in graph")
|
||||
|
||||
else:
|
||||
cnode = []
|
||||
|
||||
deps = {}
|
||||
linearized_nodes = []
|
||||
region = []
|
||||
|
||||
for n in gm.graph.nodes:
|
||||
if n.op != "placeholder" and n.op != "output":
|
||||
for n_par in n._input_nodes:
|
||||
if n_par.op != "placeholder" and n_par.name not in cnode:
|
||||
deps[n_par] -= 1
|
||||
region.append(n)
|
||||
|
||||
# if the node could free all dependencies in graph
|
||||
# we could begin a new node
|
||||
if _is_sink():
|
||||
linearized_nodes.append(region)
|
||||
region = []
|
||||
|
||||
# propagate common node attr if possible
|
||||
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
|
||||
cnode.append(n.name)
|
||||
else:
|
||||
deps[n] = len([user for user in n.users if user.op != "output"])
|
||||
|
||||
return linearized_nodes
|
|
@ -1,270 +0,0 @@
|
|||
import math
|
||||
|
||||
|
||||
def _discretize(mem_unit, values):
|
||||
return [math.ceil(value / mem_unit) for value in values]
|
||||
|
||||
|
||||
class Chain:
|
||||
|
||||
def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
|
||||
self.fweight = fw
|
||||
self.bweight = bw
|
||||
self.cweight = cw
|
||||
self.cbweight = cbw
|
||||
self.fwd_mem_tmp = ftmp
|
||||
self.bwd_mem_tmp = btmp
|
||||
self.length = len(fw)
|
||||
if check and not self.check_lengths():
|
||||
raise AttributeError("In Chain, input lists do not have consistent lengths")
|
||||
|
||||
def check_lengths(self):
|
||||
return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
|
||||
and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
|
||||
and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
|
||||
|
||||
def __repr__(self):
|
||||
chain_list = []
|
||||
for i in range(self.length):
|
||||
chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
|
||||
self.bwd_mem_tmp[i]))
|
||||
i = self.length
|
||||
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
|
||||
return chain_list.__repr__()
|
||||
|
||||
def _discretize(self, mem_unit):
|
||||
self.cweight = _discretize(mem_unit, self.cweight)
|
||||
self.cbweight = _discretize(mem_unit, self.cbweight)
|
||||
self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp)
|
||||
self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp)
|
||||
|
||||
|
||||
class Operation:
|
||||
|
||||
def shift(self, value):
|
||||
if type(self.index) is tuple:
|
||||
self.index = tuple(x + value for x in self.index)
|
||||
else:
|
||||
self.index += value
|
||||
|
||||
|
||||
class Offload(Operation):
|
||||
|
||||
def __init__(self, index, has_bar=False) -> None:
|
||||
super().__init__()
|
||||
self.index = index
|
||||
self.name = "Off"
|
||||
self.has_bar = has_bar
|
||||
if self.has_bar:
|
||||
self.name += "wBar"
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.name}_{self.index}"
|
||||
|
||||
|
||||
class Prefetch(Operation):
|
||||
|
||||
def __init__(self, index, has_bar=False) -> None:
|
||||
super().__init__()
|
||||
self.index = index
|
||||
self.name = "Pre"
|
||||
self.has_bar = has_bar
|
||||
if self.has_bar:
|
||||
self.name += "wBar"
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.name}_{self.index}"
|
||||
|
||||
|
||||
class Forward(Operation):
|
||||
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
self.name = "F"
|
||||
|
||||
def __repr__(self):
|
||||
return "{n}_{i}".format(n=self.name, i=self.index)
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return chain.fweight[self.index]
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
class ForwardEnable(Forward):
|
||||
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.name = "Fe"
|
||||
|
||||
|
||||
class ForwardNograd(Forward):
|
||||
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.name = "Fn"
|
||||
|
||||
|
||||
class ForwardCheck(Forward):
|
||||
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.name = "CF"
|
||||
|
||||
|
||||
class Forwards(Operation):
|
||||
|
||||
def __init__(self, start, end):
|
||||
self.index = (start, end)
|
||||
|
||||
def __repr__(self):
|
||||
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return sum(chain.fweight[self.index[0]:self.index[1] + 1])
|
||||
else:
|
||||
return (self.index[1] - self.index[0] + 1)
|
||||
|
||||
|
||||
def isForward(op):
|
||||
return type(op) is Forward or type(op) is Forwards
|
||||
|
||||
|
||||
class Backward(Operation):
|
||||
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def __repr__(self):
|
||||
return "B_{i}".format(i=self.index)
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return chain.bweight[self.index]
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
class Loss(Operation):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "L"
|
||||
|
||||
def cost(self, chain):
|
||||
return 0
|
||||
|
||||
|
||||
class MemoryAccess(Operation):
|
||||
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def __repr__(self):
|
||||
return "{n}_{i}".format(n=self.name, i=self.index)
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
return 0
|
||||
|
||||
|
||||
class WriteMemory(MemoryAccess):
|
||||
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.name = "WM"
|
||||
|
||||
|
||||
class ReadMemory(MemoryAccess):
|
||||
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.name = "RM"
|
||||
|
||||
|
||||
class DiscardMemory(MemoryAccess):
|
||||
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.name = "DM"
|
||||
|
||||
|
||||
class Function:
|
||||
|
||||
def __init__(self, name, *args):
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.str_args = ','.join(str(v) for v in self.args)
|
||||
|
||||
def __repr__(self):
|
||||
return "{n}({args})".format(n=self.name, args=self.str_args)
|
||||
|
||||
|
||||
class Sequence:
|
||||
|
||||
def __init__(self, function):
|
||||
self.sequence = [] #List of Operation and Sequence
|
||||
self.function = function #Description the function (name and parameters)
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.list_operations())
|
||||
|
||||
def list_operations(self):
|
||||
op_list = []
|
||||
for x in self.sequence:
|
||||
if isinstance(x, Operation):
|
||||
op_list.append(x)
|
||||
else:
|
||||
assert isinstance(x, Sequence)
|
||||
op_list += x.list_operations()
|
||||
return op_list
|
||||
|
||||
def insert(self, operation):
|
||||
self.sequence.append(operation)
|
||||
|
||||
def remove(self, operation_index):
|
||||
del self.sequence[operation_index]
|
||||
|
||||
def insert_sequence(self, sequence):
|
||||
self.sequence.append(sequence)
|
||||
|
||||
def shift(self, value):
|
||||
for x in self.sequence:
|
||||
x.shift(value)
|
||||
return self
|
||||
|
||||
def remove_useless_write(self):
|
||||
if self.sequence:
|
||||
if isinstance(self.sequence[0], WriteMemory):
|
||||
self.remove(0)
|
||||
return self
|
||||
|
||||
def get_makespan(self, chain):
|
||||
return sum(op.cost(chain) for op in self.list_operations())
|
||||
|
||||
def without_suffix(self):
|
||||
ops = self.list_operations()
|
||||
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
|
||||
try:
|
||||
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
|
||||
except ValueError:
|
||||
last_idx = -1
|
||||
if last_idx == end_of_first_phase - 1:
|
||||
return (self, None)
|
||||
chain_length = ops[end_of_first_phase -
|
||||
1].index ## Some assumption here about the sequence (finishes with Forward_L
|
||||
start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice
|
||||
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
|
||||
for i in range(last_idx + 1):
|
||||
result.insert(ops[i])
|
||||
result.insert(Loss())
|
||||
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
|
||||
position = end_of_first_phase + 1 + (chain_length - i)
|
||||
assert type(ops[position]) is Backward
|
||||
assert ops[position].index == i
|
||||
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
|
||||
result.insert(ops[i])
|
||||
return (result, start_of_fwd_enable_chain)
|
Loading…
Reference in New Issue