import copy
import math
from typing import List, Tuple

import torch
from colossalai.fx import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import \
    _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import parameter_size
from torch.fx import GraphModule, Node

from .linearize import linearize
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
                        Sequence)

INF = float("inf")


def _normalize_flops(chain: Chain, flops) -> Chain:
    """Divide every forward/backward weight of the chain by ``flops``.

    The chain is modified in place and also returned.

    Args:
        chain (Chain): chain whose ``fweight`` and ``bweight`` entries are rescaled
        flops: divisor applied to every entry

    Returns:
        Chain: the same chain object that was passed in
    """
    for i in range(chain.length):
        chain.fweight[i] /= flops
        chain.bweight[i] /= flops

    return chain


class PofoTable:
    """PofoTable
    The PofoTable contains the necessary components to store intermediate results
    of dynamic programming and the operations along the way.
    """

    def __init__(self, chain_length: int, mem_slots: int):
        """Init pofo table
        The pofo table contains two tables, opt and what, indicating values and
        operations.

        Args:
            chain_length (int): chain length
            mem_slots (int): number of memory slots
        """

        self.length = chain_length
        self.mem_slots = mem_slots

        # initializing tables
        # the first bool indicates whether the input has bar
        # opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
        # what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
        # where is_enable indicates whether we enable the gradient, is_offload indicates whether we
        # offload the input, index indicates the end of F_\empty sequence if is_enable = False
        self.opt = {
            False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
            True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
        }
        self.what = {
            False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
            True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
        }

    def _get_value(self, state, table, default):
        """Look up ``state`` in ``table``, falling back to ``default``.

        Args:
            state: tuple ``(i, act_size, df, db, input_has_bar)`` in discretized units
            table: either ``self.opt`` or ``self.what``
            default: value returned when the state exceeds the memory slots or
                has no entry in the table

        Returns:
            The stored entry, or ``default``.
        """
        i, act_size, df, db, input_has_bar = state
        if act_size + df > self.mem_slots or act_size + db > self.mem_slots:
            return default

        try:
            return table[input_has_bar][i][act_size][(df, db)]
        except KeyError:
            print(f"state not found {state}")
            # Fall back to the default instead of implicitly returning None,
            # which callers such as ``get_opt(...) + total_time`` cannot handle.
            return default

    def get_opt(self, state):
        """Return the stored value for ``state``; ``INF`` if it is unavailable."""
        return self._get_value(state, self.opt, INF)

    def get_what(self, state):
        """Return the stored decision for ``state``; ``None`` if it is unavailable.

        ``None`` is the same marker ``set_value`` receives for infeasible states,
        so callers only have to test ``what is None``.
        """
        return self._get_value(state, self.what, None)

    def set_value(self, state, opt, what):
        """Store the value ``opt`` and the decision ``what`` for ``state``."""
        i, act_size, df, db, input_has_bar = state
        self.opt[input_has_bar][i][act_size][(df, db)] = opt
        self.what[input_has_bar][i][act_size][(df, db)] = what


class PofoSolver:
    """PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
    The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs
    and its code given in the supplemental. Currently we don't use the whole set up in the original paper and reuse
    rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload.
    """

    def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None:
        """Build the rotor table and the pofo table for ``chain``.

        Args:
            chain (Chain): chain to solve; a deep copy of it is discretized for the rotor table
            max_memory (int): memory budget, split into ``mem_slots`` equal units
            bandwidth: divisor/multiplier converting between data sizes and transfer times
            mem_slots (int): number of memory slots
        """
        self.chain = chain
        self.length = chain.length
        self.max_memory = max_memory
        self.mem_slots = mem_slots
        self.mem_unit = max_memory / mem_slots
        self.bandwidth = bandwidth

        self.disc_chain = copy.deepcopy(self.chain)
        self.disc_chain._discretize(self.mem_unit)

        self.rotor_table = _compute_table(self.disc_chain, mem_slots)
        self._compute_pofo_table()

    def _discretize(self, *values) -> Tuple:
        """Convert values to memory-slot counts, rounding each one up."""
        return tuple(math.ceil(value / self.mem_unit) for value in values)

    def _undiscretize(self, *discrete_values) -> Tuple:
        """Convert slot counts back by multiplying with ``mem_unit``.

        Returns a bare number when exactly one value is given, a tuple otherwise.
        """
        if len(discrete_values) == 1:
            return discrete_values[0] * self.mem_unit
        else:
            return tuple(d * self.mem_unit for d in discrete_values)

    def _mmax_all(self, idx: int):
        """
        Calculate the maximum memory usage of Fi_all
        """

        return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx]

    def _mmax_b(self, idx: int):
        """
        Calculate the maximum memory usage of Bi
        """

        return self.chain.cbweight[idx +
                                   1] + self.chain.cweight[idx +
                                                           1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx]

    def _mmax_ng(self, i: int, j: int):
        r"""
        Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty
        """

        res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j]
        if j > i:
            res += self.chain.cweight[j]
        return res

    def _rotor_estimated_bwd(self, i, j, m, delta):
        """Estimate the backward time of stages ``i`` to ``j`` from the rotor table.

        The result is the mean of ``max(compute, comm)`` and ``compute + comm``,
        where ``compute`` is read from the rotor table and ``comm`` is
        ``delta / bandwidth``.
        """
        compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j]
        comm = delta / self.bandwidth
        return (max(compute, comm) + compute + comm) / 2

    def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
        """Return the rotor sequence for stages ``i`` to ``j``; ``delta`` is not used here."""
        return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)

    def _common_values_enable(self, state: Tuple):
        """Compute the shared quantities for running stage ``idx`` with gradient enabled.

        Args:
            state (Tuple): undiscretized state ``(idx, act_size, df, db, input_has_bar)``

        Returns:
            ``(offload_data, prefetch_data, total_time, idle_time)``, or ``None``
            when the forward or backward usage exceeds the available memory.
        """

        idx, act_size, df, db, input_has_bar = state
        input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
        mf = act_size + df + input_size
        mb = act_size + db + input_size
        mem_avail = self.max_memory - act_size - input_size
        f_usage = self._mmax_all(idx)
        b_usage = self._mmax_b(idx)

        # infeasible
        if f_usage > mem_avail or b_usage > mem_avail:
            return None

        # calculate idle time
        eps_f_beta = max(0, f_usage - self.max_memory + mf)
        eps_b_beta = max(0, b_usage - self.max_memory + mb)
        idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth

        # calculate offload and prefetch data
        offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta
        prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta

        # total_time
        total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time

        return (offload_data, prefetch_data, total_time, idle_time)

    def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False):
        """Compute the shared quantities for a checkpointed run from stage ``i`` to ``j``.

        Args:
            state (Tuple): undiscretized state ``(i, act_size, df, db, input_has_bar)``
            j (int): index of the last stage of the sequence
            iterative (bool): if True, update ``self.epsilon_tmp`` and ``self.sum_fwds``
                incrementally (the caller must reset them and call with increasing ``j``);
                otherwise recompute both from scratch

        Returns:
            ``(offload_data, prefetch_data, total_time, idle_time)``, or ``None``
            when the sequence does not fit in the available memory.
        """

        i, act_size, df, db, input_has_bar = state

        # compute new epsilon_tmp and sum_fwds
        if iterative:
            self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds)
            self.sum_fwds += self.chain.fweight[j]
        else:
            self.epsilon_tmp = max(
                self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1))
            self.sum_fwds = sum(self.chain.fweight[i:j + 1])

        input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]
        mf = act_size + df + input_size
        mem_avail = self.max_memory - act_size - input_size

        # if infeasible
        # NOTE(review): this scans every k up to the end of the chain rather than
        # only k <= j -- confirm that this is intended.
        if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail:
            return None

        eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf)
        offload_data = self.sum_fwds * self.bandwidth + eps_f_beta

        # TODO: Implement the precise backward recompute sequence mentioned in the paper
        # currently we will use an approximate way to get the backward time
        time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db)

        prefetch_data = time_backward * self.bandwidth
        idle_time = eps_f_beta / self.bandwidth
        total_time = self.sum_fwds + idle_time + time_backward

        return (offload_data, prefetch_data, total_time, idle_time)

    def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple:
        """Generate new values for next state

        Args:
            state (Tuple): undiscretized states
            do_offload (bool): bool type indicates whether we need to do offload
            common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)

        Returns:
            Tuple: (new_act_size, new_df, new_db)
        """
        idx, act_size, df, db, input_has_bar = state
        offload_data, prefetch_data, *_ = common_values
        input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
        if do_offload:
            new_act_size = act_size
            new_df = max(0, df + input_size - offload_data)
            new_db = max(0, db - prefetch_data) + input_size
        else:
            new_act_size = act_size + input_size
            new_df = max(0, df - offload_data)
            new_db = max(0, db - prefetch_data)

        return (new_act_size, new_df, new_db)

    def _compute_pofo_table(self):
        """Fill ``self.table`` for every discretized state, from the loss back to stage 0."""
        self.table = PofoTable(self.length, self.mem_slots)

        # initializing the loss
        for act_size in range(self.mem_slots + 1):
            for df in range(self.mem_slots - act_size + 1):
                for db in range(self.mem_slots - act_size + 1):
                    # undiscretize for idle time calculation
                    origin_values = self._undiscretize(act_size, df, db)

                    for input_has_bar in (False, True):
                        disc_state = (self.length, act_size, df, db, input_has_bar)
                        state = (self.length, *origin_values, input_has_bar)
                        common_values = self._common_values_enable(state)

                        # if no feasible choice
                        if common_values is None:
                            self.table.set_value(disc_state, INF, None)
                            continue

                        # if there is feasible choice
                        _, new_df, new_db = self._new_values(state, False, common_values)
                        eps_g = (new_df + new_db) / self.bandwidth
                        total_time = common_values[2] + eps_g
                        self.table.set_value(disc_state, total_time, (True, False))

        # main loop
        for i in reversed(range(self.length)):
            for act_size in range(self.mem_slots + 1):
                for df in range(self.mem_slots - act_size + 1):
                    for db in range(self.mem_slots - act_size + 1):
                        # undiscretize for idle time calculation
                        origin_values = self._undiscretize(act_size, df, db)

                        for input_has_bar in (False, True):
                            best_result = INF
                            best_choice = None
                            disc_state = (i, act_size, df, db, input_has_bar)
                            state = (i, *origin_values, input_has_bar)

                            # case 1: start with F_all
                            vals_enable = self._common_values_enable(state)
                            if vals_enable is not None:
                                for do_offload in (True, False):
                                    new_state = self._new_values(state, do_offload, vals_enable)
                                    new_state = (i + 1, *self._discretize(*new_state), True)
                                    total_time = vals_enable[2]
                                    results_all = self.table.get_opt(new_state) + total_time
                                    if results_all < best_result:
                                        best_result = results_all
                                        best_choice = (True, do_offload)

                            # case 2: start with F_ck
                            # the iterative mode of _common_values_nograd accumulates into these
                            self.sum_fwds = 0
                            self.epsilon_tmp = 0
                            for j in range(i, self.length):
                                vals_nograd = self._common_values_nograd(state, j, True)

                                # if infeasible
                                if vals_nograd is None:
                                    continue

                                for do_offload in (True, False):
                                    new_state = self._new_values(state, do_offload, vals_nograd)
                                    new_state = (j + 1, *self._discretize(*new_state), False)
                                    total_time = vals_nograd[2]
                                    result_nograd = total_time + self.table.get_opt(new_state)
                                    if result_nograd < best_result:
                                        best_result = result_nograd
                                        best_choice = (False, do_offload, j)

                            self.table.set_value(disc_state, best_result, best_choice)

    def pofo_rec(self, disc_state):
        """Rebuild the operation sequence recorded in the table, starting at ``disc_state``.

        Args:
            disc_state: discretized state ``(i, act_size, df, db, input_has_bar)``

        Returns:
            Sequence: the operations for this state and everything after it, or
            ``None`` when the table holds no decision for ``disc_state``.
        """
        i, act_size, df, db, input_has_bar = disc_state
        result = Sequence(Function("pofo", *disc_state))
        what = self.table.get_what(disc_state)
        state = self._undiscretize(act_size, df, db)
        state = (i, *state, input_has_bar)
        i, act_size, df, db, input_has_bar = state

        if what is None:
            return None

        # if loss
        if i == self.length:
            result.insert(Loss())
            return result

        if what[0]:
            do_offload = what[1]
            values = self._common_values_enable(state)
            new_state = self._discretize(*self._new_values(state, do_offload, values))
            new_state = (i + 1, *new_state, True)
            if do_offload:
                result.insert(Offload(i, input_has_bar))
            result.insert(ForwardEnable(i))
            result.insert_sequence(self.pofo_rec(new_state))
            if do_offload:
                result.insert(Prefetch(i, input_has_bar))
            result.insert(Backward(i))

        else:
            _, do_offload, j = what
            values = self._common_values_nograd(state, j)
            new_state = self._discretize(*self._new_values(state, do_offload, values))
            new_state = (j + 1, *new_state, False)
            if do_offload:
                result.insert(Offload(i, input_has_bar))
            result.insert(ForwardCheck(i))
            for k in range(i + 1, j + 1):
                result.insert(ForwardNograd(k))
            result.insert_sequence(self.pofo_rec(new_state))
            if do_offload:
                result.insert(Prefetch(i, input_has_bar))
            m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i])

            # TODO: Implement the precise backward recompute sequence mentioned in the paper
            result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db))

        return result


def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
    """Write ``activation_checkpoint`` / ``activation_offload`` attributes onto the nodes.

    Args:
        sequence (Sequence): operation sequence; it must contain a ``Loss`` op,
            which separates the forward part from the backward part
        node_list (List[List[Node]]): nodes grouped by stage, indexed by ``op.index``
    """
    op_list = sequence.list_operations()
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    fwd_list = op_list[:op_list.index(loss_op)]
    bwd_list = op_list[op_list.index(loss_op) + 1:]
    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []

    # forward annotation
    for op in fwd_list:
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                in_ckpt = False
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])

                ckpt_idx += 1
                ckpt_region = [op.index]

        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(op.index)

    # annotate the backward if there is any nested activation checkpoint
    in_recompute = False
    for op in bwd_list:
        if in_recompute:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [op.index]

            elif isinstance(op, Backward):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                in_recompute = False

        else:
            if not isinstance(op, Backward):
                in_recompute = True
                ckpt_idx = 0
                ckpt_region = []
                if isinstance(op, ForwardCheck):
                    ckpt_region.append(op.index)

    # postprocess, make sure every activation checkpoint label in the
    # same activation checkpoint region (level = 0) has the same length
    op_list = []
    for node in node_list:
        op_list += node
    ckpt_regions = _find_nested_ckpt_regions(op_list)
    for (start_idx, end_idx) in ckpt_regions:
        nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
        for idx in range(start_idx, end_idx + 1):
            op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))

    # annotate the offload
    offload_idx = 0
    for idx, op in enumerate(fwd_list):
        if isinstance(op, Offload):
            # corner case: offload input
            if op.index == 0:
                if isinstance(fwd_list[idx + 1], ForwardCheck):
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", True)
                else:
                    for n in node_list[op.index]:
                        # a list (not a tuple) so that the "annotate previous node"
                        # branch below can still assign to its last element
                        setattr(n, "activation_offload", [offload_idx, True, False])
                    offload_idx += 1

            else:
                if op.has_bar:
                    # annotate previous node
                    if hasattr(node_list[op.index - 1][0], "activation_offload"):
                        # NOTE(review): if the previous node was labelled with a bare
                        # ``True`` (ForwardCheck case), this item assignment would
                        # fail -- confirm whether that combination can occur.
                        for n in node_list[op.index - 1]:
                            n.activation_offload[-1] = True
                    else:
                        for n in node_list[op.index - 1]:
                            setattr(n, "activation_offload", [offload_idx, False, True])

                        offload_idx += 1

                # annotate this node
                if isinstance(fwd_list[idx + 1], ForwardCheck):
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", True)
                else:
                    for n in node_list[op.index]:
                        setattr(n, "activation_offload", [offload_idx, True, False])

                    offload_idx += 1


def solver_pofo(gm: ColoGraphModule,
                data,
                bandwidth,
                flops,
                mem_limit: int,
                mem_slots: int = 50,
                cnode: List[str] = None,
                eps: float = 0.0) -> ColoGraphModule:
    """Solver that combine offload and activation checkpoint
    Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html

    Args:
        gm (ColoGraphModule): ColoGraphModule derived from tracer
        data: input of the model
        bandwidth: offload bandwidth, unit Byte/s
        flops: FLOPS of device, unit FLOPs/s
        mem_limit (int): memory limit, unit Byte
        mem_slots (int, optional): number of memory slots. Defaults to 50.
        cnode (List[str], optional): common node for linearize. Defaults to None.
        eps (float, optional): epsilon for memory decay. Defaults to 0.0.
            Not referenced by the current implementation.

    Returns:
        ColoGraphModule: annotated graph module

    Raises:
        ValueError: if no sequence can be found within the memory that remains
            after subtracting the parameter size from ``mem_limit``
    """

    node_list = linearize(gm, cnode)
    mem_limit -= parameter_size(gm)

    # prepare data
    if is_compatible_with_meta():
        from colossalai.fx.profiler import MetaTensor
        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(data)
    chain: Chain = _construct_chain(node_list, data)
    chain = _normalize_flops(chain, flops)
    # currently we view loss as an op without expense
    chain.cbweight.append(0)
    chain.cweight.append(0)
    chain.fwd_mem_tmp.append(0)
    chain.bwd_mem_tmp.append(0)
    chain.fweight.append(0)
    chain.bweight.append(0)

    solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots)
    first_state = (0, 0, 0, 0, False)
    sequence = solver.pofo_rec(first_state)
    if sequence is None:
        raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")

    _annotate_from_pofo_sequence(sequence, node_list)
    setattr(gm, "__sequence__", sequence)
    return gm