diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py deleted file mode 100644 index 9ccf135d0..000000000 --- a/colossalai/fx/passes/algorithms/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .ckpt_solver_chen import chen_greedy -from .linearize import linearize -from .ckpt_solver_rotor import solver_rotor -from .ckpt_solver_pofo import solver_pofo diff --git a/colossalai/fx/passes/algorithms/build_c_ext.py b/colossalai/fx/passes/algorithms/build_c_ext.py deleted file mode 100644 index cb360cb20..000000000 --- a/colossalai/fx/passes/algorithms/build_c_ext.py +++ /dev/null @@ -1,15 +0,0 @@ -from setuptools import setup, Extension -import os - -this_dir = os.path.dirname(os.path.abspath(__file__)) -ext_modules = [Extension( - 'dynamic_programs_C_version', - sources=[os.path.join(this_dir, 'dynamic_programs.c')], -)] - -setup( - name='rotor c extension', - version='0.1', - description='rotor c extension for faster dp computing', - ext_modules=ext_modules, -) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py deleted file mode 100644 index 52000ebe5..000000000 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ /dev/null @@ -1,98 +0,0 @@ -import math -from typing import List, Set, Tuple - -import torch -from torch.fx import GraphModule, Node - -from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp - -__all__ = ['chen_greedy'] -CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr'] - - -def _all_potential_ckpt_nodes(gm: GraphModule) -> List: - """ - In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized. - """ - - def is_sink(): - """ - If we can free all memories when executing a certain node, it is a sink. - """ - return not sum((v for k, v in deps.items())) - - deps = {} - ckpt_nodes = [] - for n in gm.graph.nodes: - for n_par in n._input_nodes: - deps[n_par] -= 1 # free memory and dependencies - - # We can only put act_ckpt on these nodes - if n.op in CKPT_OP and is_sink(): - ckpt_nodes.append(n) - deps[n] = len(n.users) # add dependencies for future executions - return ckpt_nodes - - -def chen_greedy(gm: GraphModule) -> GraphModule: - """ - This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. - Note that this algorithm targets at memory optimization only, using techniques in appendix A. - - Usage: - model = resnet18() - input_sample = torch.rand(4, 3, 224, 224) - gm = symbolic_trace(model) - MetaInfoProp(gm).run(input_sample) - gm = chen_greedy(gm) - - Args: - gm (GraphModule): The module to add checkpoints - """ - - def grid_search(num_grids: int = 6) -> Set: - """ - Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. - Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A. - """ - _, b_approx = run_chen_greedy(0) - b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) - b_opt = math.inf - for b in range(b_min, b_max, (b_max - b_min) // num_grids): - ckpt_intv, b_approx = run_chen_greedy(b) - if b_approx < b_opt: - b_opt = b_approx - ckpt_opt = ckpt_intv - return ckpt_opt - - def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: - """ - This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. 
- """ - ckpt_nodes = _all_potential_ckpt_nodes(gm) - ckpt_intv = [] - temp = 0 - x = 0 - y = 0 - prev_idx = 2 - for (idx, n) in enumerate(gm.graph.nodes): - n: Node - temp += calculate_fwd_in(n) + calculate_fwd_tmp(n) - y = max(y, temp) - if temp > b and n in ckpt_nodes: - x += calculate_fwd_in(n) - temp = 0 - ckpt_intv.append((prev_idx, idx + 1)) - prev_idx = idx + 1 - return ckpt_intv, math.floor(math.sqrt(x * y)) - - gm.graph.lint() # make sure nodes are in topological order - ckpt = grid_search(num_grids=6) - node_list = list(gm.graph.nodes) - for i, seg in enumerate(ckpt): - for idx in range(*seg): - n = node_list[idx] - if n.op in CKPT_OP: - setattr(n, 'activation_checkpoint', i) - gm.recompile() - return gm diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py deleted file mode 100644 index 69e4e9f2c..000000000 --- a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py +++ /dev/null @@ -1,537 +0,0 @@ -import copy -import math -from typing import List, Tuple - -import torch -from colossalai.fx import is_compatible_with_meta -from colossalai.fx.codegen.activation_checkpoint_codegen import \ - _find_nested_ckpt_regions -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec) -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.profiler import parameter_size -from torch.fx import GraphModule, Node - -from .linearize import linearize -from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch, - Sequence) - -INF = float("inf") - - -def _normalize_flops(chain: Chain, flops) -> Chain: - """ - Normalize flops - """ - for i in range(chain.length): - chain.fweight[i] /= flops - chain.bweight[i] /= flops - - return chain - - -class PofoTable: - """PofoTable - The PofoTable contains the necessary components to store intermediate results - of dynamic programming and the operations alone the way. - """ - - def __init__(self, chain_length: int, mem_slots: int): - """Init pofo table - The pofo table contains two tables, opt and what, indicating values and - operations. 
- - Args: - chain_length (int): chain length - mem_slots (int): number of memory slots - """ - - self.length = chain_length - self.mem_slots = mem_slots - - # initializing tables - # the first bool indicates whether the input has bar - # opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db) - # what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index) - # where is_enable indicates whether we enable the gradient, is_offload indicates whether we - # offload the input, index indicates the end of F_\empty sequence if is_enable = False - self.opt = { - False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)], - True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)] - } - self.what = { - False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)], - True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)] - } - - def _get_value(self, state, table, default): - i, act_size, df, db, input_has_bar = state - if act_size + df > self.mem_slots or act_size + db > self.mem_slots: - return default - - try: - return table[input_has_bar][i][act_size][(df, db)] - except KeyError: - print(f"state not found {state}") - - def get_opt(self, state): - return self._get_value(state, self.opt, INF) - - def get_what(self, state): - return self._get_value(state, self.what, INF) - - def set_value(self, state, opt, what): - i, act_size, df, db, input_has_bar = state - self.opt[input_has_bar][i][act_size][(df, db)] = opt - self.what[input_has_bar][i][act_size][(df, db)] = what - - -class PofoSolver: - """PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html - The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs - and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse - rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload. - """ - - def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None: - self.chain = chain - self.length = chain.length - self.max_memory = max_memory - self.mem_slots = mem_slots - self.mem_unit = max_memory / mem_slots - self.bandwidth = bandwidth - - self.disc_chain = copy.deepcopy(self.chain) - self.disc_chain._discretize(self.mem_unit) - - self.rotor_table = _compute_table(self.disc_chain, mem_slots) - self._compute_pofo_table() - - def _discretize(self, *values) -> Tuple: - return tuple(math.ceil(value / self.mem_unit) for value in values) - - def _undiscretize(self, *discrete_values) -> Tuple: - if len(discrete_values) == 1: - return discrete_values[0] * self.mem_unit - else: - return tuple(d * self.mem_unit for d in discrete_values) - - def _mmax_all(self, idx: int): - """ - Calculate the maximum memory usage of Fi_all - """ - - return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx] - - def _mmax_b(self, idx: int): - """ - Calculate the maximum memory usage of Bi - """ - - return self.chain.cbweight[idx + - 1] + self.chain.cweight[idx + - 1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx] - - def _mmax_ng(self, i: int, j: int): - """ - Calculate the maximum memory usage of CF_i, F_i+1\empty, ... 
F_j\empty - """ - - res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j] - if j > i: - res += self.chain.cweight[j] - return res - - def _rotor_estimated_bwd(self, i, j, m, delta): - compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j] - comm = delta / self.bandwidth - return (max(compute, comm) + compute + comm) / 2 - - def _rotor_estimated_bwd_sequence(self, i, j, m, delta): - return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table) - - def _common_values_enable(self, state: Tuple): - - idx, act_size, df, db, input_has_bar = state - input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx] - mf = act_size + df + input_size - mb = act_size + db + input_size - mem_avail = self.max_memory - act_size - input_size - f_usage = self._mmax_all(idx) - b_usage = self._mmax_b(idx) - - # infeasible - if f_usage > mem_avail or b_usage > mem_avail: - return None - - # calculate idle time - eps_f_beta = max(0, f_usage - self.max_memory + mf) - eps_b_beta = max(0, b_usage - self.max_memory + mb) - idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth - - # calculate offload and prefetch data - offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta - prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta - - # total_time - total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time - - return (offload_data, prefetch_data, total_time, idle_time) - - def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False): - - i, act_size, df, db, input_has_bar = state - - # compute new epsilon_tmp and sum_fwds - if iterative: - self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds) - self.sum_fwds += self.chain.fweight[j] - else: - self.epsilon_tmp = max( - self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1)) - self.sum_fwds = sum(self.chain.fweight[i:j + 1]) - - input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i] - mf = act_size + df + input_size - mem_avail = self.max_memory - act_size - input_size - - # if infeasible - if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail: - return None - - eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf) - offload_data = self.sum_fwds * self.bandwidth + eps_f_beta - - # TODO: Implement the precise backward recompute sequence mentioned in the paper - # currently we will use an approximate way to get the backward time - time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db) - - prefetch_data = time_backward * self.bandwidth - idle_time = eps_f_beta / self.bandwidth - total_time = self.sum_fwds + idle_time + time_backward - - return (offload_data, prefetch_data, total_time, idle_time) - - def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple: - """Generate new values for next state - - Args: - state (Tuple): undiscretized states - do_offload (bool): bool type indicates whether we need to do offload - common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time) - - Returns: - Tuple: (new_act_size, new_df, new_db) - """ - idx, act_size, df, db, input_has_bar = state - offload_data, prefetch_data, *_ = common_values - input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx] - if do_offload: - new_act_size = act_size - new_df = max(0, df + input_size - offload_data) - 
new_db = max(0, db - prefetch_data) + input_size - else: - new_act_size = act_size + input_size - new_df = max(0, df - offload_data) - new_db = max(0, db - prefetch_data) - - return (new_act_size, new_df, new_db) - - def _compute_pofo_table(self): - self.table = PofoTable(self.length, self.mem_slots) - - # initializing the loss - for act_size in range(self.mem_slots + 1): - for df in range(self.mem_slots - act_size + 1): - for db in range(self.mem_slots - act_size + 1): - # undiscretize for idle time calculation - origin_values = self._undiscretize(act_size, df, db) - - for input_has_bar in (False, True): - disc_state = (self.length, act_size, df, db, input_has_bar) - state = (self.length, *origin_values, input_has_bar) - common_values = self._common_values_enable(state) - - # if no feasible choice - if common_values is None: - self.table.set_value(disc_state, INF, None) - continue - - # if there is feasible choice - new_act_size, new_df, new_db = self._new_values(state, False, common_values) - eps_g = (new_df + new_db) / self.bandwidth - total_time = common_values[2] + eps_g - self.table.set_value(disc_state, total_time, (True, False)) - - # main loop - for i in reversed(range(self.length)): - for act_size in range(self.mem_slots + 1): - for df in range(self.mem_slots - act_size + 1): - for db in range(self.mem_slots - act_size + 1): - # undiscretize for idle time calculation - origin_values = self._undiscretize(act_size, df, db) - - for input_has_bar in (False, True): - best_result = INF - best_choice = None - disc_state = (i, act_size, df, db, input_has_bar) - state = (i, *origin_values, input_has_bar) - - # case 1: start with F_all - vals_enable = self._common_values_enable(state) - if vals_enable is not None: - for do_offload in (True, False): - new_state = self._new_values(state, do_offload, vals_enable) - new_state = (i + 1, *self._discretize(*new_state), True) - total_time = vals_enable[2] - results_all = self.table.get_opt(new_state) + total_time - if results_all < best_result: - best_result = results_all - best_choice = (True, do_offload) - - # case 2: start with F_ck - self.sum_fwds = 0 - self.epsilon_tmp = 0 - for j in range(i, self.length): - vals_nograd = self._common_values_nograd(state, j, True) - - # if infeasible - if vals_nograd is None: - continue - - for do_offload in (True, False): - new_state = self._new_values(state, do_offload, vals_nograd) - new_state = (j + 1, *self._discretize(*new_state), False) - total_time = vals_nograd[2] - result_nograd = total_time + self.table.get_opt(new_state) - if result_nograd < best_result: - best_result = result_nograd - best_choice = (False, do_offload, j) - - self.table.set_value(disc_state, best_result, best_choice) - - def pofo_rec(self, disc_state): - i, act_size, df, db, input_has_bar = disc_state - result = Sequence(Function("pofo", *disc_state)) - what = self.table.get_what(disc_state) - state = self._undiscretize(act_size, df, db) - state = (i, *state, input_has_bar) - i, act_size, df, db, input_has_bar = state - - if what is None: - return None - - # if loss - if i == self.length: - result.insert(Loss()) - return result - - if what[0]: - do_offload = what[1] - values = self._common_values_enable(state) - new_state = self._discretize(*self._new_values(state, do_offload, values)) - new_state = (i + 1, *new_state, True) - if do_offload: - result.insert(Offload(i, input_has_bar)) - result.insert(ForwardEnable(i)) - result.insert_sequence(self.pofo_rec(new_state)) - if do_offload: - result.insert(Prefetch(i, input_has_bar)) - 
result.insert(Backward(i)) - - else: - _, do_offload, j = what - values = self._common_values_nograd(state, j) - new_state = self._discretize(*self._new_values(state, do_offload, values)) - new_state = (j + 1, *new_state, False) - if do_offload: - result.insert(Offload(i, input_has_bar)) - result.insert(ForwardCheck(i)) - for k in range(i + 1, j + 1): - result.insert(ForwardNograd(k)) - result.insert_sequence(self.pofo_rec(new_state)) - if do_offload: - result.insert(Prefetch(i, input_has_bar)) - m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]) - - #TODO: Implement the precise backward recompute sequence mentioned in the paper - result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db)) - - return result - - -def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]): - op_list = sequence.list_operations() - loss_op = next(op for op in op_list if isinstance(op, Loss)) - fwd_list = op_list[:op_list.index(loss_op)] - bwd_list = op_list[op_list.index(loss_op) + 1:] - ckpt_idx = 0 - in_ckpt = False - ckpt_region = [] - - # forward annotation - for op in fwd_list: - if in_ckpt: - if isinstance(op, ForwardNograd): - ckpt_region.append(op.index) - - elif isinstance(op, ForwardEnable): - in_ckpt = False - for node_idx in ckpt_region: - for n in node_list[node_idx]: - setattr(n, "activation_checkpoint", [ckpt_idx]) - - ckpt_idx += 1 - ckpt_region = [] - - elif isinstance(op, ForwardCheck): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - setattr(n, "activation_checkpoint", [ckpt_idx]) - - ckpt_idx += 1 - ckpt_region = [op.index] - - else: - if isinstance(op, ForwardCheck): - in_ckpt = True - ckpt_region.append(op.index) - - # annotate the backward if there is any nested activation checkpoint - in_recompute = False - for op in bwd_list: - if in_recompute: - if isinstance(op, ForwardNograd): - ckpt_region.append(op.index) - - elif isinstance(op, ForwardEnable): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - n.activation_checkpoint.append(ckpt_idx) - - ckpt_idx += 1 - ckpt_region = [] - - elif isinstance(op, ForwardCheck): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - n.activation_checkpoint.append(ckpt_idx) - - ckpt_idx += 1 - ckpt_region = [op.index] - - elif isinstance(op, Backward): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - n.activation_checkpoint.append(ckpt_idx) - - in_recompute = False - - else: - if not isinstance(op, Backward): - in_recompute = True - ckpt_idx = 0 - ckpt_region = [] - if isinstance(op, ForwardCheck): - ckpt_region.append(op.index) - - # postprocess, make sure every activation checkpoint label in the - # same activation checkpoint region (level = 0) has the same length - op_list = [] - for node in node_list: - op_list += node - ckpt_regions = _find_nested_ckpt_regions(op_list) - for (start_idx, end_idx) in ckpt_regions: - nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1)) - for idx in range(start_idx, end_idx + 1): - op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint)) - - # annotate the offload - offload_idx = 0 - for idx, op in enumerate(fwd_list): - if isinstance(op, Offload): - # corner case: offload input - if op.index == 0: - if isinstance(fwd_list[idx + 1], ForwardCheck): - for n in node_list[op.index]: - setattr(n, "activation_offload", True) - else: - for n in node_list[op.index]: - setattr(n, 
"activation_offload", (offload_idx, True, False)) - offload_idx += 1 - - else: - if op.has_bar: - # annotate previous node - if hasattr(node_list[op.index - 1][0], "activation_offload"): - for n in node_list[op.index - 1]: - n.activation_offload[-1] = True - else: - for n in node_list[op.index - 1]: - setattr(n, "activation_offload", [offload_idx, False, True]) - - offload_idx += 1 - - # annotate this node - if isinstance(fwd_list[idx + 1], ForwardCheck): - for n in node_list[op.index]: - setattr(n, "activation_offload", True) - else: - for n in node_list[op.index]: - setattr(n, "activation_offload", [offload_idx, True, False]) - - offload_idx += 1 - - -def solver_pofo(gm: ColoGraphModule, - data, - bandwidth, - flops, - mem_limit: int, - mem_slots: int = 50, - cnode: List[str] = None, - eps: float = 0.0) -> ColoGraphModule: - """Solver that combine offload and activation checkpoint - Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html - - Args: - gm (ColoGraphModule): ColoGraphModule derived from tracer - data: input of the model - bandwidth: offload bandwidth, unit Byte/s - flops: FLOPS of device, unit FLOPs/s - mem_limit (int): memory limit, unit Byte - mem_slots (int, optional): number of memory slots. Defaults to 500. - cnode (List[str], optional): common node for linearize. Defaults to None. - eps (float, optional): epsilon for memory decay. Defaults to 0.02. - - Returns: - ColoGraphModule: annotated graph module - """ - - node_list = linearize(gm, cnode) - mem_limit -= parameter_size(gm) - - # prepare data - if is_compatible_with_meta(): - from colossalai.fx.profiler import MetaTensor - data = MetaTensor(data, fake_device=next(gm.parameters()).device) - MetaInfoProp(gm).run(data) - chain: Chain = _construct_chain(node_list, data) - chain = _normalize_flops(chain, flops) - # currently we view loss as an op without expense - chain.cbweight.append(0) - chain.cweight.append(0) - chain.fwd_mem_tmp.append(0) - chain.bwd_mem_tmp.append(0) - chain.fweight.append(0) - chain.bweight.append(0) - - solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots) - first_state = (0, 0, 0, 0, False) - sequence = solver.pofo_rec(first_state) - if sequence == None: - raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory") - - _annotate_from_pofo_sequence(sequence, node_list) - setattr(gm, "__sequence__", sequence) - return gm diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py deleted file mode 100644 index 5b8d0da9f..000000000 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ /dev/null @@ -1,436 +0,0 @@ -import math -import sys -from typing import List, Tuple - -from torch.fx import Node - -from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size -from colossalai.logging import get_dist_logger - -from .linearize import linearize -from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence - -# global vairable to indicate whether the solver is failed -SOLVER_FAILED = False - - -# this is the python compute table code from rotor -# https://gitlab.inria.fr/hiepacs/rotor -# paper link: https://hal.inria.fr/hal-02352969 -def _compute_table(chain: Chain, mmax) -> Tuple: - """Returns the optimal table: a tuple 
containing: - Opt[m][lmin][lmax] with lmin = 0...chain.length - and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax - what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint - (False, j) if the optimal choice is a leaf checkpoint of length j - The computation uses dynamic programming""" - - fw = chain.fweight + [0] ## forward time - bw = chain.bweight ## backward time, not used - cw = chain.cweight + [0] ## size of x (and of y) - cbw = chain.cbweight + [0] ## size of xbar - fwd_mem_tmp = chain.fwd_mem_tmp + [0] - bwd_mem_tmp = chain.bwd_mem_tmp + [0] - - # Build table - opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] - what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] - # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation - - # Initialize borders of the tables for lmax-lmin = 0 - for m in range(mmax + 1): - for i in range(chain.length + 1): - #lmax-lmin = 0 - limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i]) - if m >= limit: ## Equation (1) - opt[m][i][i] = fw[i] + bw[i] - else: - opt[m][i][i] = float("inf") - - # Compute everything - for m in range(mmax + 1): - for d in range(1, chain.length + 1): - for i in range(chain.length + 1 - d): - # for idx in range(i+1, chain.length + 1): - idx = i + d - mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i] - if idx > i + 1: - mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx))) - if m < mmin: - opt[m][i][idx] = float("inf") - else: - leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1]) - for j in range(i + 1, idx + 1) - if m >= cw[j]] - if leaf_checkpoints: - best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) - else: - best_leaf = None - if m >= cbw[i + 1]: - chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx] - else: - chain_checkpoint = float("inf") - if best_leaf and best_leaf[1] <= chain_checkpoint: - opt[m][i][idx] = best_leaf[1] - what[m][i][idx] = (False, best_leaf[0]) - else: - opt[m][i][idx] = chain_checkpoint - what[m][i][idx] = (True,) - return (opt, what) - - -def _rec(chain: Chain, lmin, lmax, cmem, opt_table): - """ chain : the class describing the AC graph - lmin : index of the first forward to execute - lmax : upper bound index of the last forward to execute (not included) - cmem : number of available memory slots - Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]""" - if cmem <= 0: - raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem)) - opt, what = opt_table - sequence = Sequence(Function("Persistent", lmax - lmin, cmem)) - if opt[cmem][lmin][lmax] == float("inf"): - # using logger to annonce that the solver is failed - logger = get_dist_logger() - logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin, - lmax=lmax, - cmem=cmem)) - - # set global indicater SOLVER_FAILED to True - global SOLVER_FAILED - SOLVER_FAILED = True - return sequence - - if lmin == lmax: - if lmin == chain.length: - sequence.insert(Loss()) - else: - sequence.insert(ForwardEnable(lmin)) - sequence.insert(Backward(lmin)) - return sequence - - if what[cmem][lmin][lmax][0]: - sequence.insert(ForwardEnable(lmin)) - sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table)) - sequence.insert(Backward(lmin)) - else: - j = what[cmem][lmin][lmax][1] - 
sequence.insert(ForwardCheck(lmin)) - for k in range(lmin + 1, j): - sequence.insert(ForwardNograd(k)) - sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table)) - sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) - return sequence - - -def _fwd_xbar(node: List[Node]) -> int: - """Get the forward xbar of a node - - Args: - node (List[Node]): List of torch.fx Node, - indicates a node in linearized graph - - Returns: - int: xbar size, unit Byte - """ - - xbar = 0 - for n in node: - xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) - return xbar - - -def _fwd_time(node: List[Node]) -> int: - """Get the foward time of a node - - Args: - node (List[Node]): List of torch.fx Node, - indicates a node in linearized graph - - Returns: - int: foward time, extimated by flops count - """ - - fwd_time = 0 - for n in node: - # minimum flop count is needed - fwd_time += max(n.meta['fwd_flop'], 1) - return fwd_time - - -def _bwd_time(node: List[Node]) -> int: - """Get the backward time of a node - - Args: - node (List[Node]): List of torch.fx Node, - indicates a node in linearized graph - - Returns: - int: backward time, extimated by flops count - """ - - bwd_time = 0 - for n in node: - # minimum flop count is needed - bwd_time += max(n.meta['bwd_flop'], 1) - return bwd_time - - -def _get_fwd_mem_tmp(node: List[Node]) -> int: - """Get the forward temp memory of a node - This could be done by subtracting the saved activation from all output of a node - - Args: - node (List[Node]): List of torch.fx Node, - indicates a node in linearized graph - - Returns: - int: forward temp memory, unit Byte - """ - n = node[-1] - return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n) - - -def _get_bwd_mem_tmp(node: List[Node]) -> int: - """Get the backward temp memory of a node - - Args: - node (List[Node]): List of torch.fx Node, - indicates a node in linearized graph - - Returns: - int: backward temp memory, unit Byte - """ - - def _get_deps_size(): - deps_size = 0 - for k, v in deps.items(): - k: Node - if v > 0: - deps_size += k.meta['bwd_mem_out'] - if v == float('-inf'): - deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) - - return deps_size - - bwd_mem_tmp = 0 - deps = {} - - for n in reversed(node): - deps[n] = len(n.all_input_nodes) - bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp']) - - for child in n.users: - if child in deps: - deps[child] -= 1 - if deps[child] <= 0: - deps[child] = float('-inf') # free - - return bwd_mem_tmp - - -def _construct_chain(node_list: List[List[Node]], input) -> Chain: - - fwd_time = [] - bwd_time = [] - xbar_sizes = [activation_size(input)] - x_sizes = [activation_size(input)] - tmp_fwd = [] - tmp_bwd = [] - - for idx, node in enumerate(node_list): - fwd_time.append(_fwd_time(node)) - bwd_time.append(_bwd_time(node)) - x_sizes.append(calculate_fwd_out(node[-1])) - xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node))) - tmp_fwd.append(_get_fwd_mem_tmp(node)) - tmp_bwd.append(_get_bwd_mem_tmp(node)) - - bwd_time.append(0) - - # currently we view loss backward temp as zero - tmp_bwd.append(0) - - return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd) - - -def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): - op_list = sequence.list_operations() - loss_op = next(op for op in op_list if isinstance(op, Loss)) - fwd_list = op_list[:op_list.index(loss_op)] - bwd_list = op_list[op_list.index(loss_op) + 1:] - ckpt_idx = 0 - in_ckpt = False - ckpt_region = [] - - # 
forward annotation - for idx, op in enumerate(fwd_list, 0): - if in_ckpt: - if isinstance(op, ForwardNograd): - ckpt_region.append(idx) - - elif isinstance(op, ForwardEnable): - in_ckpt = False - for node_idx in ckpt_region: - for n in node_list[node_idx]: - setattr(n, "activation_checkpoint", [ckpt_idx]) - - ckpt_idx += 1 - ckpt_region = [] - - elif isinstance(op, ForwardCheck): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - setattr(n, "activation_checkpoint", [ckpt_idx]) - - ckpt_idx += 1 - ckpt_region = [idx] - - else: - if isinstance(op, ForwardCheck): - in_ckpt = True - ckpt_region.append(idx) - - # annotate the backward if there is any nested activation checkpoint - in_recompute = False - for op in bwd_list: - if in_recompute: - if isinstance(op, ForwardNograd): - ckpt_region.append(op.index) - - elif isinstance(op, ForwardEnable): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - n.activation_checkpoint.append(ckpt_idx) - - ckpt_idx += 1 - ckpt_region = [] - - elif isinstance(op, ForwardCheck): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - n.activation_checkpoint.append(ckpt_idx) - - ckpt_idx += 1 - ckpt_region = [op.index] - - elif isinstance(op, Backward): - for node_idx in ckpt_region: - for n in node_list[node_idx]: - n.activation_checkpoint.append(ckpt_idx) - - in_recompute = False - - else: - if not isinstance(op, Backward): - in_recompute = True - ckpt_idx = 0 - ckpt_region = [] - if isinstance(op, ForwardCheck): - ckpt_region.append(op.index) - - # postprocess, make sure every activation checkpoint label in the - # same activation checkpoint region (level = 0) has the same length - op_list = [] - for node in node_list: - op_list += node - ckpt_regions = _find_nested_ckpt_regions(op_list) - for (start_idx, end_idx) in ckpt_regions: - nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1)) - for idx in range(start_idx, end_idx + 1): - op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint)) - - -def solver_rotor(gm: ColoGraphModule, - data, - mem_limit: int, - mem_slots: int = 500, - cnode: List[str] = None, - eps: float = 0.0, - force_python: bool = False) -> ColoGraphModule: - """solver that automatically find activation checkpoint in rotor's manner - - Args: - gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp. - data (torch.Tensor): input data. - mem_limit (int): memory budget in Byte. - mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500. - cnode (List[Node], optional): common node list for linearize. Defaults to None. - eps (float): epsilon for memory decay. Defaults to 0.0 - force_python (bool): force to use python version of dynamic programs - - Returns: - ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute - """ - - # try to import C version solver if force_python is not set - logger = get_dist_logger() - if not force_python: - try: - from .dynamic_programs_C_version import persistent_compute_table - CVERSION = True - - # build module if module not found - except ModuleNotFoundError: - import os - import subprocess - logger.info("dynamic_programs_C_version hasn't been built! 
Building library...", ranks=[0]) - this_dir = os.path.dirname(os.path.abspath(__file__)) - result = subprocess.Popen( - [ - f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", - f"--build-lib={this_dir}" - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if result.wait() == 0: - logger.info("dynamic_programs_C_version has been built!", ranks=[0]) - from .dynamic_programs_C_version import persistent_compute_table - CVERSION = True - else: - logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0]) - CVERSION = False - else: - CVERSION = False - - # check if metainfoprop is done - if any(len(node.meta) == 0 for node in gm.graph.nodes): - raise RuntimeError( - "Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!") - - # linearize the graph - node_list = linearize(gm, cnode) - - # construct chain - mem_unit = mem_limit * (1.0 - eps) // mem_slots - chain: Chain = _construct_chain(node_list, data) - chain._discretize(mem_unit) - - # use C version if possible - if CVERSION and not force_python: - logger.info("Using C version rotor solver!", ranks=[0]) - opt_table = persistent_compute_table(chain, mem_slots) - else: - opt_table = _compute_table(chain, mem_slots) - logger.info("Using python version rotor solver!", ranks=[0]) - - # found sequence - sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) - - # if solver failed, we don't need to annotate the graph - if not SOLVER_FAILED: - _annotate_from_sequence(sequence, node_list) - - # set __sequence__ attribute to GraphModule - if SOLVER_FAILED: - setattr(gm, "__sequence__", None) - else: - setattr(gm, "__sequence__", sequence) - - # set __opttable__ attribute to GraphModule - setattr(gm, "__opttable__", opt_table[0]) - gm.recompile() - return gm diff --git a/colossalai/fx/passes/algorithms/dynamic_programs.c b/colossalai/fx/passes/algorithms/dynamic_programs.c deleted file mode 100644 index 3efad5840..000000000 --- a/colossalai/fx/passes/algorithms/dynamic_programs.c +++ /dev/null @@ -1,516 +0,0 @@ -#define PY_SSIZE_T_CLEAN -#include - -long* PySequenceToLongArray(PyObject* pylist) { - if (!(pylist && PySequence_Check(pylist))) return NULL; - Py_ssize_t len = PySequence_Size(pylist); - long* result = (long*)calloc(len + 1, sizeof(long)); - for (Py_ssize_t i = 0; i < len; ++i) { - PyObject* item = PySequence_GetItem(pylist, i); - result[i] = PyLong_AsLong(item); - Py_DECREF(item); - } - result[len] = 0; - return result; -} - -double* PySequenceToDoubleArray(PyObject* pylist) { - if (!(pylist && PySequence_Check(pylist))) return NULL; - Py_ssize_t len = PySequence_Size(pylist); - double* result = (double*)calloc(len + 1, sizeof(double)); - for (Py_ssize_t i = 0; i < len; ++i) { - PyObject* item = PySequence_GetItem(pylist, i); - result[i] = PyFloat_AsDouble(item); - Py_DECREF(item); - } - result[len] = 0; - return result; -} - -long* getLongArray(PyObject* container, const char* attributeName) { - PyObject* sequence = PyObject_GetAttrString(container, attributeName); - long* result = PySequenceToLongArray(sequence); - Py_DECREF(sequence); - return result; -} - -double* getDoubleArray(PyObject* container, const char* attributeName) { - PyObject* sequence = PyObject_GetAttrString(container, attributeName); - double* result = PySequenceToDoubleArray(sequence); - Py_DECREF(sequence); - return result; -} - -static PyObject* persistent_compute_table(PyObject* self, PyObject* args) { - PyObject* chain_param; - int mmax; 
- - if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; - - double* fw = getDoubleArray(chain_param, "fweight"); - if (!fw) return NULL; - - double* bw = getDoubleArray(chain_param, "bweight"); - if (!bw) return NULL; - - long* cw = getLongArray(chain_param, "cweight"); - if (!cw) return NULL; - - long* cbw = getLongArray(chain_param, "cbweight"); - if (!cbw) return NULL; - - long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp"); - if (!cbw) return NULL; - - long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp"); - if (!cbw) return NULL; - - PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); - if (!chain_length_param) return NULL; - long chain_length = PyLong_AsLong(chain_length_param); - Py_DECREF(chain_length_param); - - // TODO: Can be optimized by only allocating memory for l >= i - // TODO: float / int instead of double / long ? -#define OPT(m, i, l) \ - opt[(m) * (chain_length + 1) * (chain_length + 1) + \ - (i) * (chain_length + 1) + (l)] - double* opt = (double*)calloc( - (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double)); - -#define WHAT(m, i, l) \ - what[(m) * (chain_length + 1) * (chain_length + 1) + \ - (i) * (chain_length + 1) + (l)] - long* what = (long*)calloc( - (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long)); - - for (long m = 0; m <= mmax; ++m) - for (long i = 0; i <= chain_length; ++i) - // TODO: Can be optimized to remove the IF by reordering loops - if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) && - (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i])) - OPT(m, i, i) = fw[i] + bw[i]; - else - OPT(m, i, i) = INFINITY; - - for (long m = 0; m <= mmax; ++m) - for (long d = 1; d <= chain_length; ++d) { - for (long i = 0; i <= chain_length - d; ++i) { - long idx = i + d; - long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]; - if (idx > i + 1) { - long maxCostFWD = 0; - for (long j = i + 1; j < idx; j++) { - maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]); - } - mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD); - } - if ((m >= mmin)) { - long bestLeaf = -1; - double sumFw = 0; - double bestLeafCost = INFINITY; - /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j = - /// i+1 - for (long j = i + 1; j <= idx; ++j) { - sumFw += fw[j - 1]; - if (m >= cw[j]) { - double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1); - if (cost < bestLeafCost) { - bestLeafCost = cost; - bestLeaf = j; - } - } - } - double chainCost = INFINITY; - if (m >= cbw[i + 1]) - chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx); - if (bestLeafCost <= chainCost) { - OPT(m, i, idx) = bestLeafCost; - WHAT(m, i, idx) = bestLeaf; - } else { - OPT(m, i, idx) = chainCost; - WHAT(m, i, idx) = -1; - } - } else - OPT(m, i, idx) = INFINITY; - } - } - - free(fw); - free(bw); - free(cw); - free(cbw); - free(fwd_tmp); - free(bwd_tmp); - - PyObject* res_opt = PyList_New(mmax + 1); - PyObject* res_what = PyList_New(mmax + 1); - - // Convert the result into Python world - for (long m = 0; m <= mmax; ++m) { - PyObject* res_opt_m = PyList_New(chain_length + 1); - PyList_SET_ITEM(res_opt, m, res_opt_m); - PyObject* res_what_m = PyList_New(chain_length + 1); - PyList_SET_ITEM(res_what, m, res_what_m); - for (long i = 0; i <= chain_length; ++i) { - PyObject* res_opt_m_i = PyDict_New(); - PyList_SET_ITEM(res_opt_m, i, res_opt_m_i); - PyObject* res_what_m_i = PyDict_New(); - PyList_SET_ITEM(res_what_m, i, res_what_m_i); - for (long l = i; l <= chain_length; ++l) { - PyObject* res_l = PyLong_FromLong(l); - PyObject* 
res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l)); - PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l); - Py_DECREF(res_opt_m_i_l); - PyObject* res_what_m_i_l; - long what_m_i_l = WHAT(m, i, l); - if (what_m_i_l < 0) - res_what_m_i_l = Py_BuildValue("(O)", Py_True); - else - res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l); - PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l); - Py_DECREF(res_what_m_i_l); - Py_DECREF(res_l); - } - } - } - - free(opt); - free(what); - - PyObject* result = PyTuple_Pack(2, res_opt, res_what); - Py_DECREF(res_opt); - Py_DECREF(res_what); - return result; -} - -// long i = L - s, j = t - s, k = l - t -inline long floating_index_in_array(long m_factor, long m, long i, long j, - long k) { - return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j - - (j * (j - 1)) / 2 + k; -} - -typedef struct { - long sp; - long r; - long tp; -} index_t; - -static PyObject* floating_compute_table(PyObject* self, PyObject* args) { - PyObject* chain_param; - int mmax; - - if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; - - double* fw = getDoubleArray(chain_param, "fweigth"); - if (!fw) return NULL; - - double* bw = getDoubleArray(chain_param, "bweigth"); - if (!bw) return NULL; - - long* cw = getLongArray(chain_param, "cweigth"); - if (!cw) return NULL; - - long* cbw = getLongArray(chain_param, "cbweigth"); - if (!cbw) return NULL; - - long* fwd_tmp = getLongArray(chain_param, "fwd_tmp"); - if (!fwd_tmp) return NULL; - - long* bwd_tmp = getLongArray(chain_param, "bwd_tmp"); - if (!bwd_tmp) return NULL; - - PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); - if (!chain_length_param) return NULL; - long chain_length = PyLong_AsLong(chain_length_param); - Py_DECREF(chain_length_param); - - const long m_factor = - (chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12; - - // Defined for 0 <= s <= t <= l <= chain_length, for all m -#undef OPT -#define OPT(m, s, t, l) \ - opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \ - (l) - (t))] - double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double)); - -#undef WHAT -#define WHAT(m, s, t, l) \ - what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \ - (l) - (t))] - index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t)); - - double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double)); - double total = 0; - for (long i = 0; i < chain_length; ++i) { - partialSumsFW[i] = total; - total += fw[i]; - } - partialSumsFW[chain_length] = total; - - for (long m = 0; m <= mmax; ++m) - for (long i = 0; i <= chain_length; ++i) { - // TODO: Can be optimized to remove the IF by reordering loops - if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) && - (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i])) - OPT(m, i, i, i) = fw[i] + bw[i]; - else - OPT(m, i, i, i) = INFINITY; - } - - for (long m = 0; m <= mmax; ++m) - for (long d = 1; d <= chain_length; ++d) { // d = l - s - for (long s = 0; s <= chain_length - d; ++s) { - long l = s + d; - long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s]; - long memNullSecond = 0; - for (long j = s + 1; j < l; ++j) { - long val = cw[j] + cw[j + 1] + fwd_tmp[j]; - if (val > memNullSecond) memNullSecond = val; - } - for (long t = s; t <= l; ++t) { - double chainCost = INFINITY; - if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) && - (m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) { - chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 
1, s + 1, l); - } - double bestLeafCost = INFINITY; - index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1}; - if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) { - for (long r = s; r <= t; ++r) - if (cw[s] <= cw[r]) - for (long tp = t + 1; tp <= l; ++tp) - for (long sp = r + 1; sp <= tp; ++sp) { - long mp = m - cw[r] + cw[s]; - assert(mp >= 0); - if (mp >= cw[sp]) { - double value = partialSumsFW[sp] - partialSumsFW[s] + - OPT(mp - cw[sp], sp, tp, l) + - OPT(mp, r, t, tp - 1); - if (value < bestLeafCost) { - bestLeafCost = value; - bestLeaf.sp = sp; - bestLeaf.r = r; - bestLeaf.tp = tp; - } - } - } - } - if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) { - OPT(m, s, t, l) = bestLeafCost; - WHAT(m, s, t, l).sp = bestLeaf.sp; - WHAT(m, s, t, l).r = bestLeaf.r; - WHAT(m, s, t, l).tp = bestLeaf.tp; - } else { - OPT(m, s, t, l) = chainCost; - WHAT(m, s, t, l).sp = -1; - } - } - } - } - - free(fw); - free(bw); - free(cw); - free(cbw); - free(fwd_tmp); - free(bwd_tmp); - - PyObject* res_opt = PyList_New(mmax + 1); - PyObject* res_what = PyList_New(mmax + 1); - - // Convert the result into Python world - PyObject* true_tuple = Py_BuildValue("(O)", Py_True); - for (long m = 0; m <= mmax; ++m) { - PyObject* res_opt_m = PyDict_New(); - PyList_SET_ITEM(res_opt, m, res_opt_m); - PyObject* res_what_m = PyDict_New(); - PyList_SET_ITEM(res_what, m, res_what_m); - for (long s = 0; s <= chain_length; ++s) - for (long t = s; t <= chain_length; ++t) - for (long l = t; l <= chain_length; ++l) { - PyObject* key = Py_BuildValue("(lll)", s, t, l); - PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l)); - PyDict_SetItem(res_opt_m, key, value_opt); - PyObject* value_what = true_tuple; - index_t* idx_what = &WHAT(m, s, t, l); - if (idx_what->sp >= 0) - value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp, - idx_what->r, idx_what->tp); - PyDict_SetItem(res_what_m, key, value_what); - if (value_what != true_tuple) Py_DECREF(value_what); - Py_DECREF(key); - Py_DECREF(value_opt); - } - } - - Py_DECREF(true_tuple); - - free(opt); - free(what); - - PyObject* result = PyTuple_Pack(2, res_opt, res_what); - Py_DECREF(res_opt); - Py_DECREF(res_what); - return result; -} - -static PyObject* griewank_heterogeneous_compute_table(PyObject* self, - PyObject* args) { - PyObject* chain_param; - int mmax; - - if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; - - double* fw = getDoubleArray(chain_param, "fweigth"); - if (!fw) return NULL; - - double* bw = getDoubleArray(chain_param, "bweigth"); - if (!bw) return NULL; - - long* cw = getLongArray(chain_param, "cweigth"); - if (!cw) return NULL; - - long* cbw = getLongArray(chain_param, "cbweigth"); - if (!cbw) return NULL; - - PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); - if (!chain_length_param) return NULL; - long chain_length = PyLong_AsLong(chain_length_param); - Py_DECREF(chain_length_param); - - // TODO: Can be optimized by only allocating memory for l >= i - // TODO: float / int instead of double / long ? 
-#undef OPT -#define OPT(m, i, l) \ - opt[(m) * (chain_length + 1) * (chain_length + 1) + \ - (i) * (chain_length + 1) + (l)] - double* opt = (double*)calloc( - (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double)); - - // Compute partial sums - double* sumfw = (double*)calloc(chain_length, sizeof(double)); - double* sumbw = (double*)calloc(chain_length + 1, sizeof(double)); - double* sumsumfw = (double*)calloc(chain_length, sizeof(double)); - - double total = 0; - for (long i = 0; i < chain_length; ++i) { - total += fw[i]; - sumfw[i] = total; - } - - total = 0; - for (long i = 0; i < chain_length + 1; ++i) { - total += bw[i]; - sumbw[i] = total; - } - - total = 0; - for (long i = 0; i < chain_length; ++i) { - total += sumfw[i]; - sumsumfw[i] = total; - } - - for (long m = 0; m <= mmax; ++m) - for (long i = 0; i <= chain_length; ++i) { - // TODO: Can be optimized to remove the IF by reordering loops - if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1])) - OPT(m, i, i) = bw[i]; - else - OPT(m, i, i) = INFINITY; - - if (i < chain_length) { - long maxC = fmaxl(cw[i], cw[i + 1]); - long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC); - if ((m >= cbw[i]) && (m >= cw[i] + maxCB)) - OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1]; - else - OPT(m, i, i + 1) = INFINITY; - } - } - - for (long m = 0; m <= mmax; ++m) - for (long i = 0; i + 2 <= chain_length; ++i) { - long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]); - long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]); - long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il; - for (long l = i + 2; l <= chain_length; ++l) { - maxCW_il = fmax(maxCW_il, cw[l + 1]); - maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il); - long mmin = fmaxl(mminCst, maxCostFWD); - if ((m >= mmin)) { - double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0); - noCheckpointCost += - sumsumfw[l - 1] - - (i > 0 ? 
sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0); - - double valueCost = INFINITY; - if (m >= cw[i]) { - double sumFwds = 0; - for (long j = i + 1; j < l; ++j) { - sumFwds += fw[j - 1]; - valueCost = fmin( - valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1)); - } - } - OPT(m, i, l) = fmin(noCheckpointCost, valueCost); - } else - OPT(m, i, l) = INFINITY; - } - } - - free(sumfw); - free(sumbw); - free(sumsumfw); - free(fw); - free(bw); - free(cw); - free(cbw); - - PyObject* res_opt = PyList_New(mmax + 1); - - // Convert the result into Python world - for (long m = 0; m <= mmax; ++m) { - PyObject* res_opt_m = PyList_New(chain_length + 1); - PyList_SET_ITEM(res_opt, m, res_opt_m); - for (long i = 0; i <= chain_length; ++i) { - PyObject* res_opt_m_i = PyDict_New(); - PyList_SET_ITEM(res_opt_m, i, res_opt_m_i); - for (long l = i; l <= chain_length; ++l) { - PyObject* res_l = PyLong_FromLong(l - i); - PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l)); - PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l); - Py_DECREF(res_opt_m_i_l); - Py_DECREF(res_l); - } - } - } - - free(opt); - - return res_opt; -} - -static PyMethodDef dynamic_programs_methods[] = { - {"persistent_compute_table", persistent_compute_table, METH_VARARGS, - "Compute the optimal table with the persistent algorithm."}, - {"floating_compute_table", floating_compute_table, METH_VARARGS, - "Compute the optimal table with the floating algorithm."}, - {"griewank_heterogeneous_compute_table", - griewank_heterogeneous_compute_table, METH_VARARGS, - "Compute the optimal table for the Griewank Heterogeneous Model."}, - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -static struct PyModuleDef dynamic_programs_module = { - PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */ - NULL, /* module documentation, may be NULL */ - -1, /* size of per-interpreter state of the module, - or -1 if the module keeps state in global variables. */ - dynamic_programs_methods}; - -PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) { - return PyModule_Create(&dynamic_programs_module); -} diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py deleted file mode 100644 index 1a49364f5..000000000 --- a/colossalai/fx/passes/algorithms/linearize.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import List, Any -from torch.fx import GraphModule, Node -from colossalai.fx.profiler import is_inplace - -# Common nodes are type of nodes that could be seen as attributes and remain -# unchanged throughout the whole model, it will be used several times by -# different blocks of model, so that it is hard for us to linearize the graph -# when we encounter those kinds of nodes. We let users to annotate some of the -# input as common node, such as attention mask, and the followings are some of -# the ops that could actually be seen as common nodes. With our common node prop, -# we could find some of the "real" common nodes (e.g. the real attention mask -# used in BERT and GPT), the rule is simple, for node who's parents are all common -# nodes or it's op belongs to the following operations, we view this node as a -# newly born common node. 
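(Reviewer note) The comment block above describes the common-node propagation rule used by the `linearize` pass being deleted here: a node is treated as a newly born common node when all of its parents are already common nodes, or when its target is one of the ops in `COPS` below. A minimal sketch of that rule, with an assumed helper name (`propagates_common`) rather than the pass's actual internals, is:

```python
from torch.fx import Node

# Minimal sketch (assumed helper name): a node becomes a "newly born" common
# node when every parent is already a common node, or when its target is one
# of the COPS ops (getattr / getitem / size) listed in the deleted module.
def propagates_common(node: Node, common_names: set,
                      cops=("getattr", "getitem", "size")) -> bool:
    parents_common = all(p.name in common_names for p in node.all_input_nodes)
    target_name = node.target if isinstance(node.target, str) else node.target.__name__
    return parents_common or target_name in cops
```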
-# List of target name that could be seen as common node -COPS = ["getattr", "getitem", "size"] - - -def _is_cop(target: Any) -> bool: - """Check if an op could be seen as common node - - Args: - target (Any): node target - - Returns: - bool - """ - - if isinstance(target, str): - return target in COPS - else: - return target.__name__ in COPS - - -def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]: - """Linearizing the graph - - Args: - gm (GraphModule): GraphModule derived by tracing - cnode (List[str], optional): common node List, should be the subset of input. Default to None. - - Returns: - List[List[Node]]: List of list, each inside list of Node presents - the actual 'node' in linearized manner. - - Remarks: - We merge the inplace ops into the previous node. - """ - - def _is_sink() -> bool: - """Check if we can free all dependencies - - Returns: - bool - """ - - return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) - - # make sure that item in cnode is valid - if cnode: - for name in cnode: - try: - assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \ - f"common node {name} is not an input of the model" - except StopIteration: - raise ValueError(f"common node name {name} not in graph") - - else: - cnode = [] - - deps = {} - linearized_nodes = [] - region = [] - - for n in gm.graph.nodes: - if n.op != "placeholder" and n.op != "output": - for n_par in n._input_nodes: - if n_par.op != "placeholder" and n_par.name not in cnode: - deps[n_par] -= 1 - region.append(n) - - # if the node could free all dependencies in graph - # we could begin a new node - if _is_sink(): - linearized_nodes.append(region) - region = [] - - # propagate common node attr if possible - if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target): - cnode.append(n.name) - else: - deps[n] = len([user for user in n.users if user.op != "output"]) - - return linearized_nodes diff --git a/colossalai/fx/passes/algorithms/operation.py b/colossalai/fx/passes/algorithms/operation.py deleted file mode 100644 index 8bfa3452b..000000000 --- a/colossalai/fx/passes/algorithms/operation.py +++ /dev/null @@ -1,270 +0,0 @@ -import math - - -def _discretize(mem_unit, values): - return [math.ceil(value / mem_unit) for value in values] - - -class Chain: - - def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): - self.fweight = fw - self.bweight = bw - self.cweight = cw - self.cbweight = cbw - self.fwd_mem_tmp = ftmp - self.bwd_mem_tmp = btmp - self.length = len(fw) - if check and not self.check_lengths(): - raise AttributeError("In Chain, input lists do not have consistent lengths") - - def check_lengths(self): - return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1) - and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length) - and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1)) - - def __repr__(self): - chain_list = [] - for i in range(self.length): - chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i], - self.bwd_mem_tmp[i])) - i = self.length - chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i])) - return chain_list.__repr__() - - def _discretize(self, mem_unit): - self.cweight = _discretize(mem_unit, self.cweight) - self.cbweight = _discretize(mem_unit, self.cbweight) - self.fwd_mem_tmp 
= _discretize(mem_unit, self.fwd_mem_tmp) - self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp) - - -class Operation: - - def shift(self, value): - if type(self.index) is tuple: - self.index = tuple(x + value for x in self.index) - else: - self.index += value - - -class Offload(Operation): - - def __init__(self, index, has_bar=False) -> None: - super().__init__() - self.index = index - self.name = "Off" - self.has_bar = has_bar - if self.has_bar: - self.name += "wBar" - - def __repr__(self): - return f"{self.name}_{self.index}" - - -class Prefetch(Operation): - - def __init__(self, index, has_bar=False) -> None: - super().__init__() - self.index = index - self.name = "Pre" - self.has_bar = has_bar - if self.has_bar: - self.name += "wBar" - - def __repr__(self): - return f"{self.name}_{self.index}" - - -class Forward(Operation): - - def __init__(self, index): - self.index = index - self.name = "F" - - def __repr__(self): - return "{n}_{i}".format(n=self.name, i=self.index) - - def cost(self, chain: Chain): - if chain is not None: - return chain.fweight[self.index] - else: - return 1 - - -class ForwardEnable(Forward): - - def __init__(self, index): - super().__init__(index) - self.name = "Fe" - - -class ForwardNograd(Forward): - - def __init__(self, index): - super().__init__(index) - self.name = "Fn" - - -class ForwardCheck(Forward): - - def __init__(self, index): - super().__init__(index) - self.name = "CF" - - -class Forwards(Operation): - - def __init__(self, start, end): - self.index = (start, end) - - def __repr__(self): - return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) - - def cost(self, chain: Chain): - if chain is not None: - return sum(chain.fweight[self.index[0]:self.index[1] + 1]) - else: - return (self.index[1] - self.index[0] + 1) - - -def isForward(op): - return type(op) is Forward or type(op) is Forwards - - -class Backward(Operation): - - def __init__(self, index): - self.index = index - - def __repr__(self): - return "B_{i}".format(i=self.index) - - def cost(self, chain: Chain): - if chain is not None: - return chain.bweight[self.index] - else: - return 1 - - -class Loss(Operation): - - def __init__(self): - pass - - def __repr__(self): - return "L" - - def cost(self, chain): - return 0 - - -class MemoryAccess(Operation): - - def __init__(self, index): - self.index = index - - def __repr__(self): - return "{n}_{i}".format(n=self.name, i=self.index) - - def cost(self, chain: Chain): - return 0 - - -class WriteMemory(MemoryAccess): - - def __init__(self, index): - super().__init__(index) - self.name = "WM" - - -class ReadMemory(MemoryAccess): - - def __init__(self, index): - super().__init__(index) - self.name = "RM" - - -class DiscardMemory(MemoryAccess): - - def __init__(self, index): - super().__init__(index) - self.name = "DM" - - -class Function: - - def __init__(self, name, *args): - self.name = name - self.args = args - self.str_args = ','.join(str(v) for v in self.args) - - def __repr__(self): - return "{n}({args})".format(n=self.name, args=self.str_args) - - -class Sequence: - - def __init__(self, function): - self.sequence = [] #List of Operation and Sequence - self.function = function #Description the function (name and parameters) - - def __repr__(self): - return repr(self.list_operations()) - - def list_operations(self): - op_list = [] - for x in self.sequence: - if isinstance(x, Operation): - op_list.append(x) - else: - assert isinstance(x, Sequence) - op_list += x.list_operations() - return op_list - - def insert(self, operation): - 
self.sequence.append(operation) - - def remove(self, operation_index): - del self.sequence[operation_index] - - def insert_sequence(self, sequence): - self.sequence.append(sequence) - - def shift(self, value): - for x in self.sequence: - x.shift(value) - return self - - def remove_useless_write(self): - if self.sequence: - if isinstance(self.sequence[0], WriteMemory): - self.remove(0) - return self - - def get_makespan(self, chain): - return sum(op.cost(chain) for op in self.list_operations()) - - def without_suffix(self): - ops = self.list_operations() - end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] - try: - last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable) - except ValueError: - last_idx = -1 - if last_idx == end_of_first_phase - 1: - return (self, None) - chain_length = ops[end_of_first_phase - - 1].index ## Some assumption here about the sequence (finishes with Forward_L - start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice - result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain)) - for i in range(last_idx + 1): - result.insert(ops[i]) - result.insert(Loss()) - for i in range(chain_length, start_of_fwd_enable_chain - 1, -1): - position = end_of_first_phase + 1 + (chain_length - i) - assert type(ops[position]) is Backward - assert ops[position].index == i - for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)): - result.insert(ops[i]) - return (result, start_of_fwd_enable_chain)
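(Reviewer note) The files removed above made up the FX activation-checkpoint solver pipeline: `linearize` turns the traced graph into a chain, `chen_greedy` / `solver_rotor` / `solver_pofo` annotate it with checkpoint (and offload) decisions, and `dynamic_programs.c` accelerates the rotor table computation. A minimal sketch of how the removed passes were typically driven, reconstructed from the `chen_greedy` docstring and the `solver_rotor` signature (the `symbolic_trace` import path and the `resnet18` example model are assumptions taken from that docstring), is:

```python
import torch
from torchvision.models import resnet18                  # example model from the docstring

from colossalai.fx import symbolic_trace                  # assumed tracer entry point
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor   # removed by this diff

model = resnet18()
data = torch.rand(4, 3, 224, 224)

gm = symbolic_trace(model)
MetaInfoProp(gm).run(data)    # both solvers require node meta information

# Either Chen's greedy segmentation (Algorithm 3 of arXiv:1604.06174) ...
gm = chen_greedy(gm)

# ... or, alternatively, the rotor dynamic program under an explicit byte budget:
# gm = solver_rotor(gm, data, mem_limit=2 * 1024**3, mem_slots=500)
```

Call sites that still import from `colossalai.fx.passes.algorithms` will break once this diff lands and will need to migrate to whatever replaces these passes.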