mirror of https://github.com/hpcaitech/ColossalAI
parent 55dcd3051a
commit b42d3d28ed
@@ -1,4 +0,0 @@
from .ckpt_solver_chen import chen_greedy
from .linearize import linearize
from .ckpt_solver_rotor import solver_rotor
from .ckpt_solver_pofo import solver_pofo
@@ -1,15 +0,0 @@
from setuptools import setup, Extension
import os

this_dir = os.path.dirname(os.path.abspath(__file__))
ext_modules = [Extension(
    'dynamic_programs_C_version',
    sources=[os.path.join(this_dir, 'dynamic_programs.c')],
)]

setup(
    name='rotor c extension',
    version='0.1',
    description='rotor c extension for faster dp computing',
    ext_modules=ext_modules,
)
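For context, solver_rotor (further down in this diff) compiles this extension in place when it cannot be imported. A minimal sketch of that build step, mirroring the subprocess call used there, assuming the setup script above is saved as build_c_ext.py (the name the rotor solver references) next to dynamic_programs.c:

import os
import subprocess
import sys

this_dir = os.path.dirname(os.path.abspath(__file__))
# compile dynamic_programs.c into the dynamic_programs_C_version extension next to this file
subprocess.run(
    [sys.executable, os.path.join(this_dir, 'build_c_ext.py'), 'build_ext', f'--build-lib={this_dir}'],
    check=True,
)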
@@ -1,98 +0,0 @@
import math
from typing import List, Set, Tuple

import torch
from torch.fx import GraphModule, Node

from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

__all__ = ['chen_greedy']
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']


def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
    """
    In most existing frameworks of activation checkpointing, the forward graph is assumed to be linearized.
    """

    def is_sink():
        """
        If we can free all memory when executing a certain node, it is a sink.
        """
        return not sum((v for k, v in deps.items()))

    deps = {}
    ckpt_nodes = []
    for n in gm.graph.nodes:
        for n_par in n._input_nodes:
            deps[n_par] -= 1    # free memory and dependencies

        # We can only put act_ckpt on these nodes
        if n.op in CKPT_OP and is_sink():
            ckpt_nodes.append(n)
        deps[n] = len(n.users)    # add dependencies for future executions
    return ckpt_nodes


def chen_greedy(gm: GraphModule) -> GraphModule:
    """
    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
    Note that this algorithm targets memory optimization only, using the techniques in appendix A.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm)

    Args:
        gm (GraphModule): The module to add checkpoints
    """

    def grid_search(num_grids: int = 6) -> Set:
        """
        Search the ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
        Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids, as in appendix A.
        """
        _, b_approx = run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        b_opt = math.inf
        for b in range(b_min, b_max, (b_max - b_min) // num_grids):
            ckpt_intv, b_approx = run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt

    def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
        """
        This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
        """
        ckpt_nodes = _all_potential_ckpt_nodes(gm)
        ckpt_intv = []
        temp = 0
        x = 0
        y = 0
        prev_idx = 2
        for (idx, n) in enumerate(gm.graph.nodes):
            n: Node
            temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
            y = max(y, temp)
            if temp > b and n in ckpt_nodes:
                x += calculate_fwd_in(n)
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    gm.graph.lint()    # make sure nodes are in topological order
    ckpt = grid_search(num_grids=6)
    node_list = list(gm.graph.nodes)
    for i, seg in enumerate(ckpt):
        for idx in range(*seg):
            n = node_list[idx]
            if n.op in CKPT_OP:
                setattr(n, 'activation_checkpoint', i)
    gm.recompile()
    return gm
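The docstrings above describe Chen's greedy segmentation and the appendix-A grid search over the budget b. A self-contained, illustrative sketch of the same idea on a plain list of per-node activation sizes (all names and the example sizes below are made up; the real solver works on fx nodes and profiler data):

import math
from typing import List, Tuple


def greedy_segments(sizes: List[int], b: int) -> Tuple[List[Tuple[int, int]], int]:
    """Cut a new checkpoint segment whenever the running activation total exceeds budget b."""
    segments, temp, x, y, prev = [], 0, 0, 0, 0
    for idx, size in enumerate(sizes):
        temp += size
        y = max(y, temp)    # peak memory inside the current segment
        if temp > b:
            x += size       # memory kept for the checkpointed boundary
            temp = 0
            segments.append((prev, idx + 1))
            prev = idx + 1
    return segments, math.floor(math.sqrt(x * y))


# grid search over b in [sqrt(2)/2 * b_approx, sqrt(2) * b_approx], as in appendix A
sizes = [4, 2, 6, 3, 5, 1, 7, 2]
_, b_approx = greedy_segments(sizes, 0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
best_intv, best_b = None, math.inf
for b in range(b_min, b_max, max(1, (b_max - b_min) // 6)):
    intv, approx = greedy_segments(sizes, b)
    if approx < best_b:
        best_b, best_intv = approx, intv
print(best_intv, best_b)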
@@ -1,537 +0,0 @@
|
||||||
import copy
|
|
||||||
import math
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from colossalai.fx import is_compatible_with_meta
|
|
||||||
from colossalai.fx.codegen.activation_checkpoint_codegen import \
|
|
||||||
_find_nested_ckpt_regions
|
|
||||||
from colossalai.fx.graph_module import ColoGraphModule
|
|
||||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
|
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
|
||||||
from colossalai.fx.profiler import parameter_size
|
|
||||||
from torch.fx import GraphModule, Node
|
|
||||||
|
|
||||||
from .linearize import linearize
|
|
||||||
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
|
|
||||||
Sequence)
|
|
||||||
|
|
||||||
INF = float("inf")
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_flops(chain: Chain, flops) -> Chain:
|
|
||||||
"""
|
|
||||||
Normalize flops
|
|
||||||
"""
|
|
||||||
for i in range(chain.length):
|
|
||||||
chain.fweight[i] /= flops
|
|
||||||
chain.bweight[i] /= flops
|
|
||||||
|
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
class PofoTable:
|
|
||||||
"""PofoTable
|
|
||||||
The PofoTable contains the necessary components to store intermediate results
|
|
||||||
of dynamic programming and the operations alone the way.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, chain_length: int, mem_slots: int):
|
|
||||||
"""Init pofo table
|
|
||||||
The pofo table contains two tables, opt and what, indicating values and
|
|
||||||
operations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chain_length (int): chain length
|
|
||||||
mem_slots (int): number of memory slots
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.length = chain_length
|
|
||||||
self.mem_slots = mem_slots
|
|
||||||
|
|
||||||
# initializing tables
|
|
||||||
# the first bool indicates whether the input has bar
|
|
||||||
# opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
|
|
||||||
# what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
|
|
||||||
# where is_enable indicates whether we enable the gradient, is_offload indicates whether we
|
|
||||||
# offload the input, index indicates the end of F_\empty sequence if is_enable = False
|
|
||||||
self.opt = {
|
|
||||||
False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
|
|
||||||
True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
|
|
||||||
}
|
|
||||||
self.what = {
|
|
||||||
False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
|
|
||||||
True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_value(self, state, table, default):
|
|
||||||
i, act_size, df, db, input_has_bar = state
|
|
||||||
if act_size + df > self.mem_slots or act_size + db > self.mem_slots:
|
|
||||||
return default
|
|
||||||
|
|
||||||
try:
|
|
||||||
return table[input_has_bar][i][act_size][(df, db)]
|
|
||||||
except KeyError:
|
|
||||||
print(f"state not found {state}")
|
|
||||||
|
|
||||||
def get_opt(self, state):
|
|
||||||
return self._get_value(state, self.opt, INF)
|
|
||||||
|
|
||||||
def get_what(self, state):
|
|
||||||
return self._get_value(state, self.what, INF)
|
|
||||||
|
|
||||||
def set_value(self, state, opt, what):
|
|
||||||
i, act_size, df, db, input_has_bar = state
|
|
||||||
self.opt[input_has_bar][i][act_size][(df, db)] = opt
|
|
||||||
self.what[input_has_bar][i][act_size][(df, db)] = what
|
|
||||||
|
|
||||||
|
|
||||||
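# Illustrative only (not part of the original file): how the solver below reads and
# writes this table. A state is (layer index i, discretized activation size A,
# pending offload df, pending prefetch db, whether the input carries xbar).
table = PofoTable(chain_length=3, mem_slots=8)
state = (1, 2, 0, 1, False)
table.set_value(state, opt=0.25, what=(True, False))    # OCx value and the chosen decision
assert table.get_opt(state) == 0.25
assert table.get_what(state) == (True, False)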
class PofoSolver:
|
|
||||||
"""PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
|
|
||||||
The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs
|
|
||||||
and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse
|
|
||||||
rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None:
|
|
||||||
self.chain = chain
|
|
||||||
self.length = chain.length
|
|
||||||
self.max_memory = max_memory
|
|
||||||
self.mem_slots = mem_slots
|
|
||||||
self.mem_unit = max_memory / mem_slots
|
|
||||||
self.bandwidth = bandwidth
|
|
||||||
|
|
||||||
self.disc_chain = copy.deepcopy(self.chain)
|
|
||||||
self.disc_chain._discretize(self.mem_unit)
|
|
||||||
|
|
||||||
self.rotor_table = _compute_table(self.disc_chain, mem_slots)
|
|
||||||
self._compute_pofo_table()
|
|
||||||
|
|
||||||
def _discretize(self, *values) -> Tuple:
|
|
||||||
return tuple(math.ceil(value / self.mem_unit) for value in values)
|
|
||||||
|
|
||||||
def _undiscretize(self, *discrete_values) -> Tuple:
|
|
||||||
if len(discrete_values) == 1:
|
|
||||||
return discrete_values[0] * self.mem_unit
|
|
||||||
else:
|
|
||||||
return tuple(d * self.mem_unit for d in discrete_values)
|
|
||||||
|
|
||||||
def _mmax_all(self, idx: int):
|
|
||||||
"""
|
|
||||||
Calculate the maximum memory usage of Fi_all
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx]
|
|
||||||
|
|
||||||
def _mmax_b(self, idx: int):
|
|
||||||
"""
|
|
||||||
Calculate the maximum memory usage of Bi
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.chain.cbweight[idx + 1] + self.chain.cweight[idx + 1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx]
|
|
||||||
|
|
||||||
def _mmax_ng(self, i: int, j: int):
|
|
||||||
"""
|
|
||||||
Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty
|
|
||||||
"""
|
|
||||||
|
|
||||||
res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j]
|
|
||||||
if j > i:
|
|
||||||
res += self.chain.cweight[j]
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _rotor_estimated_bwd(self, i, j, m, delta):
|
|
||||||
compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j]
|
|
||||||
comm = delta / self.bandwidth
|
|
||||||
return (max(compute, comm) + compute + comm) / 2
|
|
||||||
|
|
||||||
def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
|
|
||||||
return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)
|
|
||||||
|
|
||||||
def _common_values_enable(self, state: Tuple):
|
|
||||||
|
|
||||||
idx, act_size, df, db, input_has_bar = state
|
|
||||||
input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
|
|
||||||
mf = act_size + df + input_size
|
|
||||||
mb = act_size + db + input_size
|
|
||||||
mem_avail = self.max_memory - act_size - input_size
|
|
||||||
f_usage = self._mmax_all(idx)
|
|
||||||
b_usage = self._mmax_b(idx)
|
|
||||||
|
|
||||||
# infeasible
|
|
||||||
if f_usage > mem_avail or b_usage > mem_avail:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# calculate idle time
|
|
||||||
eps_f_beta = max(0, f_usage - self.max_memory + mf)
|
|
||||||
eps_b_beta = max(0, b_usage - self.max_memory + mb)
|
|
||||||
idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth
|
|
||||||
|
|
||||||
# calculate offload and prefetch data
|
|
||||||
offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta
|
|
||||||
prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta
|
|
||||||
|
|
||||||
# total_time
|
|
||||||
total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time
|
|
||||||
|
|
||||||
return (offload_data, prefetch_data, total_time, idle_time)
|
|
||||||
|
|
||||||
def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False):
|
|
||||||
|
|
||||||
i, act_size, df, db, input_has_bar = state
|
|
||||||
|
|
||||||
# compute new epsilon_tmp and sum_fwds
|
|
||||||
if iterative:
|
|
||||||
self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds)
|
|
||||||
self.sum_fwds += self.chain.fweight[j]
|
|
||||||
else:
|
|
||||||
self.epsilon_tmp = max(
|
|
||||||
self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1))
|
|
||||||
self.sum_fwds = sum(self.chain.fweight[i:j + 1])
|
|
||||||
|
|
||||||
input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]
|
|
||||||
mf = act_size + df + input_size
|
|
||||||
mem_avail = self.max_memory - act_size - input_size
|
|
||||||
|
|
||||||
# if infeasible
|
|
||||||
if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail:
|
|
||||||
return None
|
|
||||||
|
|
||||||
eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf)
|
|
||||||
offload_data = self.sum_fwds * self.bandwidth + eps_f_beta
|
|
||||||
|
|
||||||
# TODO: Implement the precise backward recompute sequence mentioned in the paper
|
|
||||||
# currently we will use an approximate way to get the backward time
|
|
||||||
time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db)
|
|
||||||
|
|
||||||
prefetch_data = time_backward * self.bandwidth
|
|
||||||
idle_time = eps_f_beta / self.bandwidth
|
|
||||||
total_time = self.sum_fwds + idle_time + time_backward
|
|
||||||
|
|
||||||
return (offload_data, prefetch_data, total_time, idle_time)
|
|
||||||
|
|
||||||
def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple:
|
|
||||||
"""Generate new values for next state
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (Tuple): undiscretized states
|
|
||||||
do_offload (bool): whether we offload the input at this step
|
|
||||||
common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple: (new_act_size, new_df, new_db)
|
|
||||||
"""
|
|
||||||
idx, act_size, df, db, input_has_bar = state
|
|
||||||
offload_data, prefetch_data, *_ = common_values
|
|
||||||
input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
|
|
||||||
if do_offload:
|
|
||||||
new_act_size = act_size
|
|
||||||
new_df = max(0, df + input_size - offload_data)
|
|
||||||
new_db = max(0, db - prefetch_data) + input_size
|
|
||||||
else:
|
|
||||||
new_act_size = act_size + input_size
|
|
||||||
new_df = max(0, df - offload_data)
|
|
||||||
new_db = max(0, db - prefetch_data)
|
|
||||||
|
|
||||||
return (new_act_size, new_df, new_db)
|
|
||||||
|
|
||||||
def _compute_pofo_table(self):
|
|
||||||
self.table = PofoTable(self.length, self.mem_slots)
|
|
||||||
|
|
||||||
# initializing the loss
|
|
||||||
for act_size in range(self.mem_slots + 1):
|
|
||||||
for df in range(self.mem_slots - act_size + 1):
|
|
||||||
for db in range(self.mem_slots - act_size + 1):
|
|
||||||
# undiscretize for idle time calculation
|
|
||||||
origin_values = self._undiscretize(act_size, df, db)
|
|
||||||
|
|
||||||
for input_has_bar in (False, True):
|
|
||||||
disc_state = (self.length, act_size, df, db, input_has_bar)
|
|
||||||
state = (self.length, *origin_values, input_has_bar)
|
|
||||||
common_values = self._common_values_enable(state)
|
|
||||||
|
|
||||||
# if no feasible choice
|
|
||||||
if common_values is None:
|
|
||||||
self.table.set_value(disc_state, INF, None)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if there is feasible choice
|
|
||||||
new_act_size, new_df, new_db = self._new_values(state, False, common_values)
|
|
||||||
eps_g = (new_df + new_db) / self.bandwidth
|
|
||||||
total_time = common_values[2] + eps_g
|
|
||||||
self.table.set_value(disc_state, total_time, (True, False))
|
|
||||||
|
|
||||||
# main loop
|
|
||||||
for i in reversed(range(self.length)):
|
|
||||||
for act_size in range(self.mem_slots + 1):
|
|
||||||
for df in range(self.mem_slots - act_size + 1):
|
|
||||||
for db in range(self.mem_slots - act_size + 1):
|
|
||||||
# undiscretize for idle time calculation
|
|
||||||
origin_values = self._undiscretize(act_size, df, db)
|
|
||||||
|
|
||||||
for input_has_bar in (False, True):
|
|
||||||
best_result = INF
|
|
||||||
best_choice = None
|
|
||||||
disc_state = (i, act_size, df, db, input_has_bar)
|
|
||||||
state = (i, *origin_values, input_has_bar)
|
|
||||||
|
|
||||||
# case 1: start with F_all
|
|
||||||
vals_enable = self._common_values_enable(state)
|
|
||||||
if vals_enable is not None:
|
|
||||||
for do_offload in (True, False):
|
|
||||||
new_state = self._new_values(state, do_offload, vals_enable)
|
|
||||||
new_state = (i + 1, *self._discretize(*new_state), True)
|
|
||||||
total_time = vals_enable[2]
|
|
||||||
results_all = self.table.get_opt(new_state) + total_time
|
|
||||||
if results_all < best_result:
|
|
||||||
best_result = results_all
|
|
||||||
best_choice = (True, do_offload)
|
|
||||||
|
|
||||||
# case 2: start with F_ck
|
|
||||||
self.sum_fwds = 0
|
|
||||||
self.epsilon_tmp = 0
|
|
||||||
for j in range(i, self.length):
|
|
||||||
vals_nograd = self._common_values_nograd(state, j, True)
|
|
||||||
|
|
||||||
# if infeasible
|
|
||||||
if vals_nograd is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for do_offload in (True, False):
|
|
||||||
new_state = self._new_values(state, do_offload, vals_nograd)
|
|
||||||
new_state = (j + 1, *self._discretize(*new_state), False)
|
|
||||||
total_time = vals_nograd[2]
|
|
||||||
result_nograd = total_time + self.table.get_opt(new_state)
|
|
||||||
if result_nograd < best_result:
|
|
||||||
best_result = result_nograd
|
|
||||||
best_choice = (False, do_offload, j)
|
|
||||||
|
|
||||||
self.table.set_value(disc_state, best_result, best_choice)
|
|
||||||
|
|
||||||
def pofo_rec(self, disc_state):
|
|
||||||
i, act_size, df, db, input_has_bar = disc_state
|
|
||||||
result = Sequence(Function("pofo", *disc_state))
|
|
||||||
what = self.table.get_what(disc_state)
|
|
||||||
state = self._undiscretize(act_size, df, db)
|
|
||||||
state = (i, *state, input_has_bar)
|
|
||||||
i, act_size, df, db, input_has_bar = state
|
|
||||||
|
|
||||||
if what is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# if loss
|
|
||||||
if i == self.length:
|
|
||||||
result.insert(Loss())
|
|
||||||
return result
|
|
||||||
|
|
||||||
if what[0]:
|
|
||||||
do_offload = what[1]
|
|
||||||
values = self._common_values_enable(state)
|
|
||||||
new_state = self._discretize(*self._new_values(state, do_offload, values))
|
|
||||||
new_state = (i + 1, *new_state, True)
|
|
||||||
if do_offload:
|
|
||||||
result.insert(Offload(i, input_has_bar))
|
|
||||||
result.insert(ForwardEnable(i))
|
|
||||||
result.insert_sequence(self.pofo_rec(new_state))
|
|
||||||
if do_offload:
|
|
||||||
result.insert(Prefetch(i, input_has_bar))
|
|
||||||
result.insert(Backward(i))
|
|
||||||
|
|
||||||
else:
|
|
||||||
_, do_offload, j = what
|
|
||||||
values = self._common_values_nograd(state, j)
|
|
||||||
new_state = self._discretize(*self._new_values(state, do_offload, values))
|
|
||||||
new_state = (j + 1, *new_state, False)
|
|
||||||
if do_offload:
|
|
||||||
result.insert(Offload(i, input_has_bar))
|
|
||||||
result.insert(ForwardCheck(i))
|
|
||||||
for k in range(i + 1, j + 1):
|
|
||||||
result.insert(ForwardNograd(k))
|
|
||||||
result.insert_sequence(self.pofo_rec(new_state))
|
|
||||||
if do_offload:
|
|
||||||
result.insert(Prefetch(i, input_has_bar))
|
|
||||||
m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i])
|
|
||||||
|
|
||||||
#TODO: Implement the precise backward recompute sequence mentioned in the paper
|
|
||||||
result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
|
||||||
op_list = sequence.list_operations()
|
|
||||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
|
||||||
fwd_list = op_list[:op_list.index(loss_op)]
|
|
||||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
|
||||||
ckpt_idx = 0
|
|
||||||
in_ckpt = False
|
|
||||||
ckpt_region = []
|
|
||||||
|
|
||||||
# forward annotation
|
|
||||||
for op in fwd_list:
|
|
||||||
if in_ckpt:
|
|
||||||
if isinstance(op, ForwardNograd):
|
|
||||||
ckpt_region.append(op.index)
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardEnable):
|
|
||||||
in_ckpt = False
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = []
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardCheck):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = [op.index]
|
|
||||||
|
|
||||||
else:
|
|
||||||
if isinstance(op, ForwardCheck):
|
|
||||||
in_ckpt = True
|
|
||||||
ckpt_region.append(op.index)
|
|
||||||
|
|
||||||
# annotate the backward if there is any nested activation checkpoint
|
|
||||||
in_recompute = False
|
|
||||||
for op in bwd_list:
|
|
||||||
if in_recompute:
|
|
||||||
if isinstance(op, ForwardNograd):
|
|
||||||
ckpt_region.append(op.index)
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardEnable):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
n.activation_checkpoint.append(ckpt_idx)
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = []
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardCheck):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
n.activation_checkpoint.append(ckpt_idx)
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = [op.index]
|
|
||||||
|
|
||||||
elif isinstance(op, Backward):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
n.activation_checkpoint.append(ckpt_idx)
|
|
||||||
|
|
||||||
in_recompute = False
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not isinstance(op, Backward):
|
|
||||||
in_recompute = True
|
|
||||||
ckpt_idx = 0
|
|
||||||
ckpt_region = []
|
|
||||||
if isinstance(op, ForwardCheck):
|
|
||||||
ckpt_region.append(op.index)
|
|
||||||
|
|
||||||
# postprocess, make sure every activation checkpoint label in the
|
|
||||||
# same activation checkpoint region (level = 0) has the same length
|
|
||||||
op_list = []
|
|
||||||
for node in node_list:
|
|
||||||
op_list += node
|
|
||||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
|
||||||
for (start_idx, end_idx) in ckpt_regions:
|
|
||||||
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
|
||||||
for idx in range(start_idx, end_idx + 1):
|
|
||||||
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
|
|
||||||
|
|
||||||
# annotate the offload
|
|
||||||
offload_idx = 0
|
|
||||||
for idx, op in enumerate(fwd_list):
|
|
||||||
if isinstance(op, Offload):
|
|
||||||
# corner case: offload input
|
|
||||||
if op.index == 0:
|
|
||||||
if isinstance(fwd_list[idx + 1], ForwardCheck):
|
|
||||||
for n in node_list[op.index]:
|
|
||||||
setattr(n, "activation_offload", True)
|
|
||||||
else:
|
|
||||||
for n in node_list[op.index]:
|
|
||||||
setattr(n, "activation_offload", (offload_idx, True, False))
|
|
||||||
offload_idx += 1
|
|
||||||
|
|
||||||
else:
|
|
||||||
if op.has_bar:
|
|
||||||
# annotate previous node
|
|
||||||
if hasattr(node_list[op.index - 1][0], "activation_offload"):
|
|
||||||
for n in node_list[op.index - 1]:
|
|
||||||
n.activation_offload[-1] = True
|
|
||||||
else:
|
|
||||||
for n in node_list[op.index - 1]:
|
|
||||||
setattr(n, "activation_offload", [offload_idx, False, True])
|
|
||||||
|
|
||||||
offload_idx += 1
|
|
||||||
|
|
||||||
# annotate this node
|
|
||||||
if isinstance(fwd_list[idx + 1], ForwardCheck):
|
|
||||||
for n in node_list[op.index]:
|
|
||||||
setattr(n, "activation_offload", True)
|
|
||||||
else:
|
|
||||||
for n in node_list[op.index]:
|
|
||||||
setattr(n, "activation_offload", [offload_idx, True, False])
|
|
||||||
|
|
||||||
offload_idx += 1
|
|
||||||
|
|
||||||
|
|
||||||
def solver_pofo(gm: ColoGraphModule,
|
|
||||||
data,
|
|
||||||
bandwidth,
|
|
||||||
flops,
|
|
||||||
mem_limit: int,
|
|
||||||
mem_slots: int = 50,
|
|
||||||
cnode: List[str] = None,
|
|
||||||
eps: float = 0.0) -> ColoGraphModule:
|
|
||||||
"""Solver that combine offload and activation checkpoint
|
|
||||||
Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gm (ColoGraphModule): ColoGraphModule derived from tracer
|
|
||||||
data: input of the model
|
|
||||||
bandwidth: offload bandwidth, unit Byte/s
|
|
||||||
flops: FLOPS of device, unit FLOPs/s
|
|
||||||
mem_limit (int): memory limit, unit Byte
|
|
||||||
mem_slots (int, optional): number of memory slots. Defaults to 500.
|
|
||||||
cnode (List[str], optional): common node for linearize. Defaults to None.
|
|
||||||
eps (float, optional): epsilon for memory decay. Defaults to 0.02.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ColoGraphModule: annotated graph module
|
|
||||||
"""
|
|
||||||
|
|
||||||
node_list = linearize(gm, cnode)
|
|
||||||
mem_limit -= parameter_size(gm)
|
|
||||||
|
|
||||||
# prepare data
|
|
||||||
if is_compatible_with_meta():
|
|
||||||
from colossalai.fx.profiler import MetaTensor
|
|
||||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
|
||||||
MetaInfoProp(gm).run(data)
|
|
||||||
chain: Chain = _construct_chain(node_list, data)
|
|
||||||
chain = _normalize_flops(chain, flops)
|
|
||||||
# currently we view loss as an op without expense
|
|
||||||
chain.cbweight.append(0)
|
|
||||||
chain.cweight.append(0)
|
|
||||||
chain.fwd_mem_tmp.append(0)
|
|
||||||
chain.bwd_mem_tmp.append(0)
|
|
||||||
chain.fweight.append(0)
|
|
||||||
chain.bweight.append(0)
|
|
||||||
|
|
||||||
solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots)
|
|
||||||
first_state = (0, 0, 0, 0, False)
|
|
||||||
sequence = solver.pofo_rec(first_state)
|
|
||||||
if sequence is None:
|
|
||||||
raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
|
|
||||||
|
|
||||||
_annotate_from_pofo_sequence(sequence, node_list)
|
|
||||||
setattr(gm, "__sequence__", sequence)
|
|
||||||
return gm
|
|
|
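Before the next file, a minimal usage sketch of solver_pofo under stated assumptions: the tracing-helper import path and all numeric values below are hypothetical, and the model/input mirror the Usage block of chen_greedy above.

import torch
from torchvision.models import resnet18

from colossalai.fx import symbolic_trace    # assumed tracing helper; the exact import path may differ
from colossalai.fx.passes.algorithms import solver_pofo

model = resnet18()
data = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model)    # solver_pofo expects a ColoGraphModule
gm = solver_pofo(
    gm,
    data,
    bandwidth=12 * 1024**3,    # hypothetical 12 GiB/s offload bandwidth
    flops=15 * 10**12,         # hypothetical 15 TFLOPs/s device throughput
    mem_limit=4 * 1024**3,     # hypothetical 4 GiB memory budget
)
print(gm.__sequence__)    # the solved forward/backward schedule with offload and prefetch ops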
@@ -1,436 +0,0 @@
|
||||||
import math
|
|
||||||
import sys
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from torch.fx import Node
|
|
||||||
|
|
||||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
|
||||||
from colossalai.fx.graph_module import ColoGraphModule
|
|
||||||
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
|
|
||||||
from .linearize import linearize
|
|
||||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
|
|
||||||
|
|
||||||
# global variable to indicate whether the solver has failed
|
|
||||||
SOLVER_FAILED = False
|
|
||||||
|
|
||||||
|
|
||||||
# this is the python compute table code from rotor
|
|
||||||
# https://gitlab.inria.fr/hiepacs/rotor
|
|
||||||
# paper link: https://hal.inria.fr/hal-02352969
|
|
||||||
def _compute_table(chain: Chain, mmax) -> Tuple:
|
|
||||||
"""Returns the optimal table: a tuple containing:
|
|
||||||
Opt[m][lmin][lmax] with lmin = 0...chain.length
|
|
||||||
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
|
|
||||||
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
|
|
||||||
(False, j) if the optimal choice is a leaf checkpoint of length j
|
|
||||||
The computation uses dynamic programming"""
|
|
||||||
|
|
||||||
fw = chain.fweight + [0] ## forward time
|
|
||||||
bw = chain.bweight ## backward time
|
|
||||||
cw = chain.cweight + [0] ## size of x (and of y)
|
|
||||||
cbw = chain.cbweight + [0] ## size of xbar
|
|
||||||
fwd_mem_tmp = chain.fwd_mem_tmp + [0]
|
|
||||||
bwd_mem_tmp = chain.bwd_mem_tmp + [0]
|
|
||||||
|
|
||||||
# Build table
|
|
||||||
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
|
|
||||||
what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
|
|
||||||
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
|
|
||||||
|
|
||||||
# Initialize borders of the tables for lmax-lmin = 0
|
|
||||||
for m in range(mmax + 1):
|
|
||||||
for i in range(chain.length + 1):
|
|
||||||
#lmax-lmin = 0
|
|
||||||
limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
|
|
||||||
if m >= limit: ## Equation (1)
|
|
||||||
opt[m][i][i] = fw[i] + bw[i]
|
|
||||||
else:
|
|
||||||
opt[m][i][i] = float("inf")
|
|
||||||
|
|
||||||
# Compute everything
|
|
||||||
for m in range(mmax + 1):
|
|
||||||
for d in range(1, chain.length + 1):
|
|
||||||
for i in range(chain.length + 1 - d):
|
|
||||||
# for idx in range(i+1, chain.length + 1):
|
|
||||||
idx = i + d
|
|
||||||
mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
|
|
||||||
if idx > i + 1:
|
|
||||||
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
|
|
||||||
if m < mmin:
|
|
||||||
opt[m][i][idx] = float("inf")
|
|
||||||
else:
|
|
||||||
leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
|
|
||||||
for j in range(i + 1, idx + 1)
|
|
||||||
if m >= cw[j]]
|
|
||||||
if leaf_checkpoints:
|
|
||||||
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
|
|
||||||
else:
|
|
||||||
best_leaf = None
|
|
||||||
if m >= cbw[i + 1]:
|
|
||||||
chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
|
|
||||||
else:
|
|
||||||
chain_checkpoint = float("inf")
|
|
||||||
if best_leaf and best_leaf[1] <= chain_checkpoint:
|
|
||||||
opt[m][i][idx] = best_leaf[1]
|
|
||||||
what[m][i][idx] = (False, best_leaf[0])
|
|
||||||
else:
|
|
||||||
opt[m][i][idx] = chain_checkpoint
|
|
||||||
what[m][i][idx] = (True,)
|
|
||||||
return (opt, what)
|
|
||||||
|
|
||||||
|
|
||||||
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
|
||||||
""" chain : the class describing the AC graph
|
|
||||||
lmin : index of the first forward to execute
|
|
||||||
lmax : upper bound index of the last forward to execute (not included)
|
|
||||||
cmem : number of available memory slots
|
|
||||||
Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
|
|
||||||
if cmem <= 0:
|
|
||||||
raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem))
|
|
||||||
opt, what = opt_table
|
|
||||||
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
|
|
||||||
if opt[cmem][lmin][lmax] == float("inf"):
|
|
||||||
# use the logger to announce that the solver has failed
|
|
||||||
logger = get_dist_logger()
|
|
||||||
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
|
|
||||||
lmax=lmax,
|
|
||||||
cmem=cmem))
|
|
||||||
|
|
||||||
# set the global indicator SOLVER_FAILED to True
|
|
||||||
global SOLVER_FAILED
|
|
||||||
SOLVER_FAILED = True
|
|
||||||
return sequence
|
|
||||||
|
|
||||||
if lmin == lmax:
|
|
||||||
if lmin == chain.length:
|
|
||||||
sequence.insert(Loss())
|
|
||||||
else:
|
|
||||||
sequence.insert(ForwardEnable(lmin))
|
|
||||||
sequence.insert(Backward(lmin))
|
|
||||||
return sequence
|
|
||||||
|
|
||||||
if what[cmem][lmin][lmax][0]:
|
|
||||||
sequence.insert(ForwardEnable(lmin))
|
|
||||||
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
|
|
||||||
sequence.insert(Backward(lmin))
|
|
||||||
else:
|
|
||||||
j = what[cmem][lmin][lmax][1]
|
|
||||||
sequence.insert(ForwardCheck(lmin))
|
|
||||||
for k in range(lmin + 1, j):
|
|
||||||
sequence.insert(ForwardNograd(k))
|
|
||||||
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
|
|
||||||
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
|
|
||||||
return sequence
|
|
||||||
|
|
||||||
|
|
||||||
def _fwd_xbar(node: List[Node]) -> int:
|
|
||||||
"""Get the forward xbar of a node
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (List[Node]): List of torch.fx Node,
|
|
||||||
indicates a node in linearized graph
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: xbar size, unit Byte
|
|
||||||
"""
|
|
||||||
|
|
||||||
xbar = 0
|
|
||||||
for n in node:
|
|
||||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
|
||||||
return xbar
|
|
||||||
|
|
||||||
|
|
||||||
def _fwd_time(node: List[Node]) -> int:
|
|
||||||
"""Get the foward time of a node
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (List[Node]): List of torch.fx Node,
|
|
||||||
indicates a node in linearized graph
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: forward time, estimated by flop count
|
|
||||||
"""
|
|
||||||
|
|
||||||
fwd_time = 0
|
|
||||||
for n in node:
|
|
||||||
# minimum flop count is needed
|
|
||||||
fwd_time += max(n.meta['fwd_flop'], 1)
|
|
||||||
return fwd_time
|
|
||||||
|
|
||||||
|
|
||||||
def _bwd_time(node: List[Node]) -> int:
|
|
||||||
"""Get the backward time of a node
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (List[Node]): List of torch.fx Node,
|
|
||||||
indicates a node in linearized graph
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: backward time, estimated by flop count
|
|
||||||
"""
|
|
||||||
|
|
||||||
bwd_time = 0
|
|
||||||
for n in node:
|
|
||||||
# minimum flop count is needed
|
|
||||||
bwd_time += max(n.meta['bwd_flop'], 1)
|
|
||||||
return bwd_time
|
|
||||||
|
|
||||||
|
|
||||||
def _get_fwd_mem_tmp(node: List[Node]) -> int:
|
|
||||||
"""Get the forward temp memory of a node
|
|
||||||
This can be computed by subtracting the saved activations from all outputs of a node
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (List[Node]): List of torch.fx Node,
|
|
||||||
indicates a node in linearized graph
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: forward temp memory, unit Byte
|
|
||||||
"""
|
|
||||||
n = node[-1]
|
|
||||||
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_bwd_mem_tmp(node: List[Node]) -> int:
|
|
||||||
"""Get the backward temp memory of a node
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (List[Node]): List of torch.fx Node,
|
|
||||||
indicates a node in linearized graph
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: backward temp memory, unit Byte
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_deps_size():
|
|
||||||
deps_size = 0
|
|
||||||
for k, v in deps.items():
|
|
||||||
k: Node
|
|
||||||
if v > 0:
|
|
||||||
deps_size += k.meta['bwd_mem_out']
|
|
||||||
if v == float('-inf'):
|
|
||||||
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
|
|
||||||
|
|
||||||
return deps_size
|
|
||||||
|
|
||||||
bwd_mem_tmp = 0
|
|
||||||
deps = {}
|
|
||||||
|
|
||||||
for n in reversed(node):
|
|
||||||
deps[n] = len(n.all_input_nodes)
|
|
||||||
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
|
|
||||||
|
|
||||||
for child in n.users:
|
|
||||||
if child in deps:
|
|
||||||
deps[child] -= 1
|
|
||||||
if deps[child] <= 0:
|
|
||||||
deps[child] = float('-inf') # free
|
|
||||||
|
|
||||||
return bwd_mem_tmp
|
|
||||||
|
|
||||||
|
|
||||||
def _construct_chain(node_list: List[List[Node]], input) -> Chain:
|
|
||||||
|
|
||||||
fwd_time = []
|
|
||||||
bwd_time = []
|
|
||||||
xbar_sizes = [activation_size(input)]
|
|
||||||
x_sizes = [activation_size(input)]
|
|
||||||
tmp_fwd = []
|
|
||||||
tmp_bwd = []
|
|
||||||
|
|
||||||
for idx, node in enumerate(node_list):
|
|
||||||
fwd_time.append(_fwd_time(node))
|
|
||||||
bwd_time.append(_bwd_time(node))
|
|
||||||
x_sizes.append(calculate_fwd_out(node[-1]))
|
|
||||||
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
|
|
||||||
tmp_fwd.append(_get_fwd_mem_tmp(node))
|
|
||||||
tmp_bwd.append(_get_bwd_mem_tmp(node))
|
|
||||||
|
|
||||||
bwd_time.append(0)
|
|
||||||
|
|
||||||
# currently we view loss backward temp as zero
|
|
||||||
tmp_bwd.append(0)
|
|
||||||
|
|
||||||
return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
|
|
||||||
|
|
||||||
|
|
||||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
|
||||||
op_list = sequence.list_operations()
|
|
||||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
|
||||||
fwd_list = op_list[:op_list.index(loss_op)]
|
|
||||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
|
||||||
ckpt_idx = 0
|
|
||||||
in_ckpt = False
|
|
||||||
ckpt_region = []
|
|
||||||
|
|
||||||
# forward annotation
|
|
||||||
for idx, op in enumerate(fwd_list, 0):
|
|
||||||
if in_ckpt:
|
|
||||||
if isinstance(op, ForwardNograd):
|
|
||||||
ckpt_region.append(idx)
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardEnable):
|
|
||||||
in_ckpt = False
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = []
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardCheck):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = [idx]
|
|
||||||
|
|
||||||
else:
|
|
||||||
if isinstance(op, ForwardCheck):
|
|
||||||
in_ckpt = True
|
|
||||||
ckpt_region.append(idx)
|
|
||||||
|
|
||||||
# annotate the backward if there is any nested activation checkpoint
|
|
||||||
in_recompute = False
|
|
||||||
for op in bwd_list:
|
|
||||||
if in_recompute:
|
|
||||||
if isinstance(op, ForwardNograd):
|
|
||||||
ckpt_region.append(op.index)
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardEnable):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
n.activation_checkpoint.append(ckpt_idx)
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = []
|
|
||||||
|
|
||||||
elif isinstance(op, ForwardCheck):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
n.activation_checkpoint.append(ckpt_idx)
|
|
||||||
|
|
||||||
ckpt_idx += 1
|
|
||||||
ckpt_region = [op.index]
|
|
||||||
|
|
||||||
elif isinstance(op, Backward):
|
|
||||||
for node_idx in ckpt_region:
|
|
||||||
for n in node_list[node_idx]:
|
|
||||||
n.activation_checkpoint.append(ckpt_idx)
|
|
||||||
|
|
||||||
in_recompute = False
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not isinstance(op, Backward):
|
|
||||||
in_recompute = True
|
|
||||||
ckpt_idx = 0
|
|
||||||
ckpt_region = []
|
|
||||||
if isinstance(op, ForwardCheck):
|
|
||||||
ckpt_region.append(op.index)
|
|
||||||
|
|
||||||
# postprocess, make sure every activation checkpoint label in the
|
|
||||||
# same activation checkpoint region (level = 0) has the same length
|
|
||||||
op_list = []
|
|
||||||
for node in node_list:
|
|
||||||
op_list += node
|
|
||||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
|
||||||
for (start_idx, end_idx) in ckpt_regions:
|
|
||||||
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
|
||||||
for idx in range(start_idx, end_idx + 1):
|
|
||||||
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
|
|
||||||
|
|
||||||
|
|
||||||
def solver_rotor(gm: ColoGraphModule,
|
|
||||||
data,
|
|
||||||
mem_limit: int,
|
|
||||||
mem_slots: int = 500,
|
|
||||||
cnode: List[str] = None,
|
|
||||||
eps: float = 0.0,
|
|
||||||
force_python: bool = False) -> ColoGraphModule:
|
|
||||||
"""solver that automatically find activation checkpoint in rotor's manner
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
|
|
||||||
data (torch.Tensor): input data.
|
|
||||||
mem_limit (int): memory budget in Byte.
|
|
||||||
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
|
|
||||||
cnode (List[str], optional): common node list for linearize. Defaults to None.
|
|
||||||
eps (float): epsilon for memory decay. Defaults to 0.0
|
|
||||||
force_python (bool): force to use python version of dynamic programs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ColoGraphModule: annotated ColoGraphModule with a __sequence__ attribute
|
|
||||||
"""
|
|
||||||
|
|
||||||
# try to import C version solver if force_python is not set
|
|
||||||
logger = get_dist_logger()
|
|
||||||
if not force_python:
|
|
||||||
try:
|
|
||||||
from .dynamic_programs_C_version import persistent_compute_table
|
|
||||||
CVERSION = True
|
|
||||||
|
|
||||||
# build module if module not found
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
|
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
result = subprocess.Popen(
|
|
||||||
[
|
|
||||||
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
|
|
||||||
f"--build-lib={this_dir}"
|
|
||||||
],
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
if result.wait() == 0:
|
|
||||||
logger.info("dynamic_programs_C_version has been built!", ranks=[0])
|
|
||||||
from .dynamic_programs_C_version import persistent_compute_table
|
|
||||||
CVERSION = True
|
|
||||||
else:
|
|
||||||
logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
|
|
||||||
CVERSION = False
|
|
||||||
else:
|
|
||||||
CVERSION = False
|
|
||||||
|
|
||||||
# check if metainfoprop is done
|
|
||||||
if any(len(node.meta) == 0 for node in gm.graph.nodes):
|
|
||||||
raise RuntimeError(
|
|
||||||
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")
|
|
||||||
|
|
||||||
# linearize the graph
|
|
||||||
node_list = linearize(gm, cnode)
|
|
||||||
|
|
||||||
# construct chain
|
|
||||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
|
||||||
chain: Chain = _construct_chain(node_list, data)
|
|
||||||
chain._discretize(mem_unit)
|
|
||||||
|
|
||||||
# use C version if possible
|
|
||||||
if CVERSION and not force_python:
|
|
||||||
logger.info("Using C version rotor solver!", ranks=[0])
|
|
||||||
opt_table = persistent_compute_table(chain, mem_slots)
|
|
||||||
else:
|
|
||||||
opt_table = _compute_table(chain, mem_slots)
|
|
||||||
logger.info("Using python version rotor solver!", ranks=[0])
|
|
||||||
|
|
||||||
# find the optimal sequence
|
|
||||||
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
|
||||||
|
|
||||||
# if solver failed, we don't need to annotate the graph
|
|
||||||
if not SOLVER_FAILED:
|
|
||||||
_annotate_from_sequence(sequence, node_list)
|
|
||||||
|
|
||||||
# set __sequence__ attribute to GraphModule
|
|
||||||
if SOLVER_FAILED:
|
|
||||||
setattr(gm, "__sequence__", None)
|
|
||||||
else:
|
|
||||||
setattr(gm, "__sequence__", sequence)
|
|
||||||
|
|
||||||
# set __opttable__ attribute to GraphModule
|
|
||||||
setattr(gm, "__opttable__", opt_table[0])
|
|
||||||
gm.recompile()
|
|
||||||
return gm
|
|
|
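For reference, a short, illustrative way to inspect what solver_rotor leaves behind on an annotated module (it assumes gm has already been traced, run through MetaInfoProp, and passed through solver_rotor as documented above):

# the recursive checkpoint schedule and the DP value table stored by the solver
print(gm.__sequence__)
print(type(gm.__opttable__))

# per-node labels consumed by the activation checkpoint codegen:
# a list of region indices, padded with None for nested regions of unequal depth
for node in gm.graph.nodes:
    if hasattr(node, 'activation_checkpoint'):
        print(node.name, node.activation_checkpoint)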
@@ -1,516 +0,0 @@
|
||||||
#define PY_SSIZE_T_CLEAN
|
|
||||||
#include <Python.h>
|
|
||||||
|
|
||||||
long* PySequenceToLongArray(PyObject* pylist) {
|
|
||||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
|
||||||
Py_ssize_t len = PySequence_Size(pylist);
|
|
||||||
long* result = (long*)calloc(len + 1, sizeof(long));
|
|
||||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
|
||||||
PyObject* item = PySequence_GetItem(pylist, i);
|
|
||||||
result[i] = PyLong_AsLong(item);
|
|
||||||
Py_DECREF(item);
|
|
||||||
}
|
|
||||||
result[len] = 0;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
double* PySequenceToDoubleArray(PyObject* pylist) {
|
|
||||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
|
||||||
Py_ssize_t len = PySequence_Size(pylist);
|
|
||||||
double* result = (double*)calloc(len + 1, sizeof(double));
|
|
||||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
|
||||||
PyObject* item = PySequence_GetItem(pylist, i);
|
|
||||||
result[i] = PyFloat_AsDouble(item);
|
|
||||||
Py_DECREF(item);
|
|
||||||
}
|
|
||||||
result[len] = 0;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
long* getLongArray(PyObject* container, const char* attributeName) {
|
|
||||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
|
||||||
long* result = PySequenceToLongArray(sequence);
|
|
||||||
Py_DECREF(sequence);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
double* getDoubleArray(PyObject* container, const char* attributeName) {
|
|
||||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
|
||||||
double* result = PySequenceToDoubleArray(sequence);
|
|
||||||
Py_DECREF(sequence);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
|
||||||
PyObject* chain_param;
|
|
||||||
int mmax;
|
|
||||||
|
|
||||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
|
||||||
|
|
||||||
double* fw = getDoubleArray(chain_param, "fweight");
|
|
||||||
if (!fw) return NULL;
|
|
||||||
|
|
||||||
double* bw = getDoubleArray(chain_param, "bweight");
|
|
||||||
if (!bw) return NULL;
|
|
||||||
|
|
||||||
long* cw = getLongArray(chain_param, "cweight");
|
|
||||||
if (!cw) return NULL;
|
|
||||||
|
|
||||||
long* cbw = getLongArray(chain_param, "cbweight");
|
|
||||||
if (!cbw) return NULL;
|
|
||||||
|
|
||||||
long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
|
|
||||||
if (!fwd_tmp) return NULL;
|
|
||||||
|
|
||||||
long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
|
|
||||||
if (!bwd_tmp) return NULL;
|
|
||||||
|
|
||||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
|
||||||
if (!chain_length_param) return NULL;
|
|
||||||
long chain_length = PyLong_AsLong(chain_length_param);
|
|
||||||
Py_DECREF(chain_length_param);
|
|
||||||
|
|
||||||
// TODO: Can be optimized by only allocating memory for l >= i
|
|
||||||
// TODO: float / int instead of double / long ?
|
|
||||||
#define OPT(m, i, l) \
|
|
||||||
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
|
|
||||||
(i) * (chain_length + 1) + (l)]
|
|
||||||
double* opt = (double*)calloc(
|
|
||||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
|
|
||||||
|
|
||||||
#define WHAT(m, i, l) \
|
|
||||||
what[(m) * (chain_length + 1) * (chain_length + 1) + \
|
|
||||||
(i) * (chain_length + 1) + (l)]
|
|
||||||
long* what = (long*)calloc(
|
|
||||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));
|
|
||||||
|
|
||||||
for (long m = 0; m <= mmax; ++m)
|
|
||||||
for (long i = 0; i <= chain_length; ++i)
|
|
||||||
// TODO: Can be optimized to remove the IF by reordering loops
|
|
||||||
if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
|
|
||||||
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
|
|
||||||
OPT(m, i, i) = fw[i] + bw[i];
|
|
||||||
else
|
|
||||||
OPT(m, i, i) = INFINITY;
|
|
||||||
|
|
||||||
for (long m = 0; m <= mmax; ++m)
|
|
||||||
for (long d = 1; d <= chain_length; ++d) {
|
|
||||||
for (long i = 0; i <= chain_length - d; ++i) {
|
|
||||||
long idx = i + d;
|
|
||||||
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
|
|
||||||
if (idx > i + 1) {
|
|
||||||
long maxCostFWD = 0;
|
|
||||||
for (long j = i + 1; j < idx; j++) {
|
|
||||||
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
|
|
||||||
}
|
|
||||||
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
|
|
||||||
}
|
|
||||||
if ((m >= mmin)) {
|
|
||||||
long bestLeaf = -1;
|
|
||||||
double sumFw = 0;
|
|
||||||
double bestLeafCost = INFINITY;
|
|
||||||
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
|
|
||||||
/// i+1
|
|
||||||
for (long j = i + 1; j <= idx; ++j) {
|
|
||||||
sumFw += fw[j - 1];
|
|
||||||
if (m >= cw[j]) {
|
|
||||||
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
|
|
||||||
if (cost < bestLeafCost) {
|
|
||||||
bestLeafCost = cost;
|
|
||||||
bestLeaf = j;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
double chainCost = INFINITY;
|
|
||||||
if (m >= cbw[i + 1])
|
|
||||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
|
|
||||||
if (bestLeafCost <= chainCost) {
|
|
||||||
OPT(m, i, idx) = bestLeafCost;
|
|
||||||
WHAT(m, i, idx) = bestLeaf;
|
|
||||||
} else {
|
|
||||||
OPT(m, i, idx) = chainCost;
|
|
||||||
WHAT(m, i, idx) = -1;
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
OPT(m, i, idx) = INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
free(fw);
|
|
||||||
free(bw);
|
|
||||||
free(cw);
|
|
||||||
free(cbw);
|
|
||||||
free(fwd_tmp);
|
|
||||||
free(bwd_tmp);
|
|
||||||
|
|
||||||
PyObject* res_opt = PyList_New(mmax + 1);
|
|
||||||
PyObject* res_what = PyList_New(mmax + 1);
|
|
||||||
|
|
||||||
// Convert the result into Python world
|
|
||||||
for (long m = 0; m <= mmax; ++m) {
|
|
||||||
PyObject* res_opt_m = PyList_New(chain_length + 1);
|
|
||||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
|
||||||
PyObject* res_what_m = PyList_New(chain_length + 1);
|
|
||||||
PyList_SET_ITEM(res_what, m, res_what_m);
|
|
||||||
for (long i = 0; i <= chain_length; ++i) {
|
|
||||||
PyObject* res_opt_m_i = PyDict_New();
|
|
||||||
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
|
|
||||||
PyObject* res_what_m_i = PyDict_New();
|
|
||||||
PyList_SET_ITEM(res_what_m, i, res_what_m_i);
|
|
||||||
for (long l = i; l <= chain_length; ++l) {
|
|
||||||
PyObject* res_l = PyLong_FromLong(l);
|
|
||||||
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
|
|
||||||
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
|
|
||||||
Py_DECREF(res_opt_m_i_l);
|
|
||||||
PyObject* res_what_m_i_l;
|
|
||||||
long what_m_i_l = WHAT(m, i, l);
|
|
||||||
if (what_m_i_l < 0)
|
|
||||||
res_what_m_i_l = Py_BuildValue("(O)", Py_True);
|
|
||||||
else
|
|
||||||
res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
|
|
||||||
PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
|
|
||||||
Py_DECREF(res_what_m_i_l);
|
|
||||||
Py_DECREF(res_l);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
free(opt);
|
|
||||||
free(what);
|
|
||||||
|
|
||||||
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
|
|
||||||
Py_DECREF(res_opt);
|
|
||||||
Py_DECREF(res_what);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// long i = L - s, j = t - s, k = l - t
|
|
||||||
inline long floating_index_in_array(long m_factor, long m, long i, long j,
|
|
||||||
long k) {
|
|
||||||
return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
|
|
||||||
(j * (j - 1)) / 2 + k;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
long sp;
|
|
||||||
long r;
|
|
||||||
long tp;
|
|
||||||
} index_t;
|
|
||||||
|
|
||||||
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
|
|
||||||
PyObject* chain_param;
|
|
||||||
int mmax;
|
|
||||||
|
|
||||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
|
||||||
|
|
||||||
double* fw = getDoubleArray(chain_param, "fweigth");
|
|
||||||
if (!fw) return NULL;
|
|
||||||
|
|
||||||
double* bw = getDoubleArray(chain_param, "bweigth");
|
|
||||||
if (!bw) return NULL;
|
|
||||||
|
|
||||||
long* cw = getLongArray(chain_param, "cweigth");
|
|
||||||
if (!cw) return NULL;
|
|
||||||
|
|
||||||
long* cbw = getLongArray(chain_param, "cbweigth");
|
|
||||||
if (!cbw) return NULL;
|
|
||||||
|
|
||||||
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
|
|
||||||
if (!fwd_tmp) return NULL;
|
|
||||||
|
|
||||||
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
|
|
||||||
if (!bwd_tmp) return NULL;
|
|
||||||
|
|
||||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
|
||||||
if (!chain_length_param) return NULL;
|
|
||||||
long chain_length = PyLong_AsLong(chain_length_param);
|
|
||||||
Py_DECREF(chain_length_param);
|
|
||||||
|
|
||||||
const long m_factor =
|
|
||||||
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
|
|
||||||
|
|
||||||
// Defined for 0 <= s <= t <= l <= chain_length, for all m
|
|
||||||
#undef OPT
|
|
||||||
#define OPT(m, s, t, l) \
|
|
||||||
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
|
|
||||||
(l) - (t))]
|
|
||||||
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
|
|
||||||
|
|
||||||
#undef WHAT
|
|
||||||
#define WHAT(m, s, t, l) \
|
|
||||||
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
|
|
||||||
(l) - (t))]
|
|
||||||
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
|
|
||||||
|
|
||||||
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
|
|
||||||
double total = 0;
|
|
||||||
for (long i = 0; i < chain_length; ++i) {
|
|
||||||
partialSumsFW[i] = total;
|
|
||||||
total += fw[i];
|
|
||||||
}
|
|
||||||
partialSumsFW[chain_length] = total;
|
|
||||||
|
|
||||||
for (long m = 0; m <= mmax; ++m)
|
|
||||||
for (long i = 0; i <= chain_length; ++i) {
|
|
||||||
// TODO: Can be optimized to remove the IF by reordering loops
|
|
||||||
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
|
|
||||||
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
|
|
||||||
OPT(m, i, i, i) = fw[i] + bw[i];
|
|
||||||
else
|
|
||||||
OPT(m, i, i, i) = INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (long m = 0; m <= mmax; ++m)
|
|
||||||
for (long d = 1; d <= chain_length; ++d) { // d = l - s
|
|
||||||
for (long s = 0; s <= chain_length - d; ++s) {
|
|
||||||
long l = s + d;
|
|
||||||
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
|
|
||||||
long memNullSecond = 0;
|
|
||||||
for (long j = s + 1; j < l; ++j) {
|
|
||||||
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
|
|
||||||
if (val > memNullSecond) memNullSecond = val;
|
|
||||||
}
|
|
||||||
for (long t = s; t <= l; ++t) {
|
|
||||||
double chainCost = INFINITY;
|
|
||||||
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
|
|
||||||
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
|
|
||||||
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
|
|
||||||
}
|
|
||||||
double bestLeafCost = INFINITY;
|
|
||||||
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
|
|
||||||
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
|
|
||||||
for (long r = s; r <= t; ++r)
|
|
||||||
if (cw[s] <= cw[r])
|
|
||||||
for (long tp = t + 1; tp <= l; ++tp)
|
|
||||||
for (long sp = r + 1; sp <= tp; ++sp) {
|
|
||||||
long mp = m - cw[r] + cw[s];
|
|
||||||
assert(mp >= 0);
|
|
||||||
if (mp >= cw[sp]) {
|
|
||||||
double value = partialSumsFW[sp] - partialSumsFW[s] +
|
|
||||||
OPT(mp - cw[sp], sp, tp, l) +
|
|
||||||
OPT(mp, r, t, tp - 1);
|
|
||||||
if (value < bestLeafCost) {
|
|
||||||
bestLeafCost = value;
|
|
||||||
bestLeaf.sp = sp;
|
|
||||||
bestLeaf.r = r;
|
|
||||||
bestLeaf.tp = tp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
|
|
||||||
OPT(m, s, t, l) = bestLeafCost;
|
|
||||||
WHAT(m, s, t, l).sp = bestLeaf.sp;
|
|
||||||
WHAT(m, s, t, l).r = bestLeaf.r;
|
|
||||||
WHAT(m, s, t, l).tp = bestLeaf.tp;
|
|
||||||
} else {
|
|
||||||
OPT(m, s, t, l) = chainCost;
|
|
||||||
WHAT(m, s, t, l).sp = -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}

  free(fw);
  free(bw);
  free(cw);
  free(cbw);
  free(fwd_tmp);
  free(bwd_tmp);

  PyObject* res_opt = PyList_New(mmax + 1);
  PyObject* res_what = PyList_New(mmax + 1);

  // Convert the result into Python world
  PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
  for (long m = 0; m <= mmax; ++m) {
    PyObject* res_opt_m = PyDict_New();
    PyList_SET_ITEM(res_opt, m, res_opt_m);
    PyObject* res_what_m = PyDict_New();
    PyList_SET_ITEM(res_what, m, res_what_m);
    for (long s = 0; s <= chain_length; ++s)
      for (long t = s; t <= chain_length; ++t)
        for (long l = t; l <= chain_length; ++l) {
          PyObject* key = Py_BuildValue("(lll)", s, t, l);
          PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
          PyDict_SetItem(res_opt_m, key, value_opt);
          PyObject* value_what = true_tuple;
          index_t* idx_what = &WHAT(m, s, t, l);
          if (idx_what->sp >= 0)
            value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
                                       idx_what->r, idx_what->tp);
          PyDict_SetItem(res_what_m, key, value_what);
          if (value_what != true_tuple) Py_DECREF(value_what);
          Py_DECREF(key);
          Py_DECREF(value_opt);
        }
  }

  Py_DECREF(true_tuple);

  free(opt);
  free(what);

  PyObject* result = PyTuple_Pack(2, res_opt, res_what);
  Py_DECREF(res_opt);
  Py_DECREF(res_what);
  return result;
}

static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
                                                      PyObject* args) {
  PyObject* chain_param;
  int mmax;

  if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;

  double* fw = getDoubleArray(chain_param, "fweight");
  if (!fw) return NULL;

  double* bw = getDoubleArray(chain_param, "bweight");
  if (!bw) return NULL;

  long* cw = getLongArray(chain_param, "cweight");
  if (!cw) return NULL;

  long* cbw = getLongArray(chain_param, "cbweight");
  if (!cbw) return NULL;

  PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
  if (!chain_length_param) return NULL;
  long chain_length = PyLong_AsLong(chain_length_param);
  Py_DECREF(chain_length_param);

  // TODO: Can be optimized by only allocating memory for l >= i
  // TODO: float / int instead of double / long ?
#undef OPT
#define OPT(m, i, l) \
  opt[(m) * (chain_length + 1) * (chain_length + 1) + \
      (i) * (chain_length + 1) + (l)]
  double* opt = (double*)calloc(
      (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));

  // Compute partial sums
  double* sumfw = (double*)calloc(chain_length, sizeof(double));
  double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
  double* sumsumfw = (double*)calloc(chain_length, sizeof(double));

  double total = 0;
  for (long i = 0; i < chain_length; ++i) {
    total += fw[i];
    sumfw[i] = total;
  }

  total = 0;
  for (long i = 0; i < chain_length + 1; ++i) {
    total += bw[i];
    sumbw[i] = total;
  }

  total = 0;
  for (long i = 0; i < chain_length; ++i) {
    total += sumfw[i];
    sumsumfw[i] = total;
  }

  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i <= chain_length; ++i) {
      // TODO: Can be optimized to remove the IF by reordering loops
      if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
        OPT(m, i, i) = bw[i];
      else
        OPT(m, i, i) = INFINITY;

      if (i < chain_length) {
        long maxC = fmaxl(cw[i], cw[i + 1]);
        long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
        if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
          OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
        else
          OPT(m, i, i + 1) = INFINITY;
      }
    }

  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i + 2 <= chain_length; ++i) {
      long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
      long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
      long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
      for (long l = i + 2; l <= chain_length; ++l) {
        maxCW_il = fmax(maxCW_il, cw[l + 1]);
        maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
        long mmin = fmaxl(mminCst, maxCostFWD);
        if ((m >= mmin)) {
          double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
          noCheckpointCost +=
              sumsumfw[l - 1] -
              (i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);

          double valueCost = INFINITY;
          if (m >= cw[i]) {
            double sumFwds = 0;
            for (long j = i + 1; j < l; ++j) {
              sumFwds += fw[j - 1];
              valueCost = fmin(
                  valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
            }
          }
          OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
        } else
          OPT(m, i, l) = INFINITY;
      }
    }

  free(sumfw);
  free(sumbw);
  free(sumsumfw);
  free(fw);
  free(bw);
  free(cw);
  free(cbw);

  PyObject* res_opt = PyList_New(mmax + 1);

  // Convert the result into Python world
  for (long m = 0; m <= mmax; ++m) {
    PyObject* res_opt_m = PyList_New(chain_length + 1);
    PyList_SET_ITEM(res_opt, m, res_opt_m);
    for (long i = 0; i <= chain_length; ++i) {
      PyObject* res_opt_m_i = PyDict_New();
      PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
      for (long l = i; l <= chain_length; ++l) {
        PyObject* res_l = PyLong_FromLong(l - i);
        PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
        PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
        Py_DECREF(res_opt_m_i_l);
        Py_DECREF(res_l);
      }
    }
  }

  free(opt);

  return res_opt;
}

static PyMethodDef dynamic_programs_methods[] = {
    {"persistent_compute_table", persistent_compute_table, METH_VARARGS,
     "Compute the optimal table with the persistent algorithm."},
    {"floating_compute_table", floating_compute_table, METH_VARARGS,
     "Compute the optimal table with the floating algorithm."},
    {"griewank_heterogeneous_compute_table",
     griewank_heterogeneous_compute_table, METH_VARARGS,
     "Compute the optimal table for the Griewank Heterogeneous Model."},
    {NULL, NULL, 0, NULL} /* Sentinel */
};

static struct PyModuleDef dynamic_programs_module = {
    PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
    NULL, /* module documentation, may be NULL */
    -1,   /* size of per-interpreter state of the module,
             or -1 if the module keeps state in global variables. */
    dynamic_programs_methods};

PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
  return PyModule_Create(&dynamic_programs_module);
}
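
/*
 * Illustrative usage sketch (added for clarity, not part of the original
 * file). Once built as dynamic_programs_C_version, the floating table could
 * be queried from Python roughly like this, assuming `chain` exposes the
 * fweight/bweight/cweight/cbweight lists plus a `length` attribute and that
 * `mmax` is the discretized memory budget:
 *
 *   import dynamic_programs_C_version as dp
 *   opt, what = dp.floating_compute_table(chain, mmax)
 *   # opt[m][(s, t, l)]  -> optimal cost of sub-problem (s, t, l) under
 *   #                       memory budget m (INFINITY if infeasible)
 *   # what[m][(s, t, l)] -> (True,) when no split index was recorded, or
 *   #                       (False, (sp, r, tp)) with the recorded indices
 *   best = opt[mmax][(0, 0, chain.length)]
 */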
@ -1,94 +0,0 @@
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace

# Common nodes are nodes that behave like attributes and remain unchanged
# throughout the whole model. They are used several times by different blocks
# of the model, which makes it hard to linearize the graph when we encounter
# them. We let users annotate some of the inputs as common nodes (for example
# the attention mask), and the following are some of the ops whose outputs can
# also be treated as common nodes. With this common-node propagation we can
# find the "real" common nodes (e.g. the actual attention mask used in BERT
# and GPT). The rule is simple: a node whose parents are all common nodes, or
# whose op belongs to the following operations, is itself treated as a newly
# born common node.
# List of target names that can be treated as common nodes
COPS = ["getattr", "getitem", "size"]
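# Illustrative example (added for clarity, not part of the original file): if
# "attention_mask" is passed in `cnode`, then a node such as
# getitem(attention_mask, 0) or attention_mask.size() has only common-node
# parents or a target listed in COPS, so it is marked as a common node as well
# and excluded from the dependency counting used to split regions below.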


def _is_cop(target: Any) -> bool:
    """Check if an op can be seen as a common node.

    Args:
        target (Any): node target

    Returns:
        bool
    """

    if isinstance(target, str):
        return target in COPS
    else:
        return target.__name__ in COPS


def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
    """Linearize the graph.

    Args:
        gm (GraphModule): GraphModule derived by tracing
        cnode (List[str], optional): common node list, should be a subset of the inputs. Defaults to None.

    Returns:
        List[List[Node]]: list of lists, where each inner list of Nodes represents
        one region of the linearized graph.

    Remarks:
        We merge inplace ops into the previous node.
    """

    def _is_sink() -> bool:
        """Check if we can free all dependencies.

        Returns:
            bool
        """

        return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))

    # make sure that every item in cnode is valid
    if cnode:
        for name in cnode:
            try:
                assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
                    f"common node {name} is not an input of the model"
            except StopIteration:
                raise ValueError(f"common node name {name} not in graph")

    else:
        cnode = []

    deps = {}
    linearized_nodes = []
    region = []

    for n in gm.graph.nodes:
        if n.op != "placeholder" and n.op != "output":
            for n_par in n._input_nodes:
                if n_par.op != "placeholder" and n_par.name not in cnode:
                    deps[n_par] -= 1
            region.append(n)

            # if the node can free all of its dependencies in the graph,
            # we can begin a new region
            if _is_sink():
                linearized_nodes.append(region)
                region = []

            # propagate the common node attribute if possible
            if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
                cnode.append(n.name)
            else:
                deps[n] = len([user for user in n.users if user.op != "output"])

    return linearized_nodes
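

# Illustrative sketch (added for clarity, not part of the original file).
# `gm` is assumed to be a torch.fx.GraphModule obtained by tracing the model,
# and "attention_mask" an input name that should be treated as a common node.
#
#   regions = linearize(gm, cnode=["attention_mask"])
#   for region in regions:
#       print([node.name for node in region])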
@ -1,270 +0,0 @@
import math


def _discretize(mem_unit, values):
    return [math.ceil(value / mem_unit) for value in values]


class Chain:

    def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
        self.fweight = fw
        self.bweight = bw
        self.cweight = cw
        self.cbweight = cbw
        self.fwd_mem_tmp = ftmp
        self.bwd_mem_tmp = btmp
        self.length = len(fw)
        if check and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
                and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
                and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))

    def __repr__(self):
        chain_list = []
        for i in range(self.length):
            chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
                               self.bwd_mem_tmp[i]))
        i = self.length
        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
        return chain_list.__repr__()

    def _discretize(self, mem_unit):
        self.cweight = _discretize(mem_unit, self.cweight)
        self.cbweight = _discretize(mem_unit, self.cbweight)
        self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp)
        self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp)
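

# Illustrative sketch (added for clarity, not part of the original file): a
# 2-stage chain whose list lengths satisfy check_lengths() (fw/ftmp carry
# `length` entries, the other lists `length + 1`), discretized with a memory
# unit of 512 bytes.
if __name__ == "__main__":
    demo_chain = Chain(fw=[1.0, 2.0],
                       bw=[2.0, 3.0, 0.5],
                       cw=[1024, 2048, 2048],
                       cbw=[1024, 1024, 512],
                       ftmp=[0, 0],
                       btmp=[0, 0, 0])
    demo_chain._discretize(512)    # cw becomes [2, 4, 4], cbw becomes [2, 2, 1]
    print(demo_chain)              # prints the per-stage tuples built by __repr__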


class Operation:

    def shift(self, value):
        if type(self.index) is tuple:
            self.index = tuple(x + value for x in self.index)
        else:
            self.index += value


class Offload(Operation):

    def __init__(self, index, has_bar=False) -> None:
        super().__init__()
        self.index = index
        self.name = "Off"
        self.has_bar = has_bar
        if self.has_bar:
            self.name += "wBar"

    def __repr__(self):
        return f"{self.name}_{self.index}"


class Prefetch(Operation):

    def __init__(self, index, has_bar=False) -> None:
        super().__init__()
        self.index = index
        self.name = "Pre"
        self.has_bar = has_bar
        if self.has_bar:
            self.name += "wBar"

    def __repr__(self):
        return f"{self.name}_{self.index}"


class Forward(Operation):

    def __init__(self, index):
        self.index = index
        self.name = "F"

    def __repr__(self):
        return "{n}_{i}".format(n=self.name, i=self.index)

    def cost(self, chain: Chain):
        if chain is not None:
            return chain.fweight[self.index]
        else:
            return 1


class ForwardEnable(Forward):

    def __init__(self, index):
        super().__init__(index)
        self.name = "Fe"


class ForwardNograd(Forward):

    def __init__(self, index):
        super().__init__(index)
        self.name = "Fn"


class ForwardCheck(Forward):

    def __init__(self, index):
        super().__init__(index)
        self.name = "CF"


class Forwards(Operation):

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])

    def cost(self, chain: Chain):
        if chain is not None:
            return sum(chain.fweight[self.index[0]:self.index[1] + 1])
        else:
            return (self.index[1] - self.index[0] + 1)


def isForward(op):
    return type(op) is Forward or type(op) is Forwards


class Backward(Operation):

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return "B_{i}".format(i=self.index)

    def cost(self, chain: Chain):
        if chain is not None:
            return chain.bweight[self.index]
        else:
            return 1


class Loss(Operation):

    def __init__(self):
        pass

    def __repr__(self):
        return "L"

    def cost(self, chain):
        return 0


class MemoryAccess(Operation):

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return "{n}_{i}".format(n=self.name, i=self.index)

    def cost(self, chain: Chain):
        return 0


class WriteMemory(MemoryAccess):

    def __init__(self, index):
        super().__init__(index)
        self.name = "WM"


class ReadMemory(MemoryAccess):

    def __init__(self, index):
        super().__init__(index)
        self.name = "RM"


class DiscardMemory(MemoryAccess):

    def __init__(self, index):
        super().__init__(index)
        self.name = "DM"


class Function:

    def __init__(self, name, *args):
        self.name = name
        self.args = args
        self.str_args = ','.join(str(v) for v in self.args)

    def __repr__(self):
        return "{n}({args})".format(n=self.name, args=self.str_args)


class Sequence:

    def __init__(self, function):
        self.sequence = []    # list of Operation and Sequence
        self.function = function    # description of the function (name and parameters)

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        op_list = []
        for x in self.sequence:
            if isinstance(x, Operation):
                op_list.append(x)
            else:
                assert isinstance(x, Sequence)
                op_list += x.list_operations()
        return op_list

    def insert(self, operation):
        self.sequence.append(operation)

    def remove(self, operation_index):
        del self.sequence[operation_index]

    def insert_sequence(self, sequence):
        self.sequence.append(sequence)

    def shift(self, value):
        for x in self.sequence:
            x.shift(value)
        return self

    def remove_useless_write(self):
        if self.sequence:
            if isinstance(self.sequence[0], WriteMemory):
                self.remove(0)
        return self

    def get_makespan(self, chain):
        return sum(op.cost(chain) for op in self.list_operations())

    def without_suffix(self):
        ops = self.list_operations()
        end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
        try:
            last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
        except ValueError:
            last_idx = -1
        if last_idx == end_of_first_phase - 1:
            return (self, None)
        # Some assumptions about the sequence here: the first phase ends with
        # Forward_L and the second phase starts with B_L, but this should hold
        # in practice.
        chain_length = ops[end_of_first_phase - 1].index
        start_of_fwd_enable_chain = ops[last_idx + 1].index
        result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
        for i in range(last_idx + 1):
            result.insert(ops[i])
        result.insert(Loss())
        for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
            position = end_of_first_phase + 1 + (chain_length - i)
            assert type(ops[position]) is Backward
            assert ops[position].index == i
        for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
            result.insert(ops[i])
        return (result, start_of_fwd_enable_chain)
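

# Illustrative sketch (added for clarity, not part of the original file):
# build a tiny schedule by hand and measure its makespan. With chain=None,
# each forward/backward op costs 1 and Loss costs 0.
if __name__ == "__main__":
    seq = Sequence(Function("Demo", 2))
    for op in [ForwardEnable(0), ForwardCheck(1), Loss(), Backward(1), Backward(0)]:
        seq.insert(op)
    print(seq)                     # [Fe_0, CF_1, L, B_1, B_0]
    print(seq.get_makespan(None))  # 1 + 1 + 0 + 1 + 1 = 4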