From e8a9bebc8770b9430f4150a400e6fef43cf02d4f Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Thu, 3 Nov 2022 12:32:51 +0800 Subject: [PATCH] [autoparallel] refactor and add rotorc. (#1789) * [autoparallel] refactor and add rotorc. * [autoparallel] refactor and add rotorc. --- .../auto_parallel/checkpoint/build_c_ext.py | 16 ++ .../checkpoint/ckpt_solver_rotor.c | 197 ++++++++++++++++++ .../checkpoint/ckpt_solver_rotor.py | 162 ++++++++------ .../auto_parallel/checkpoint/operation.py | 83 ++------ colossalai/fx/profiler/profiler.py | 4 + 5 files changed, 333 insertions(+), 129 deletions(-) create mode 100644 colossalai/auto_parallel/checkpoint/build_c_ext.py create mode 100644 colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py new file mode 100644 index 000000000..af4349865 --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -0,0 +1,16 @@ +import os + +from setuptools import Extension, setup + +this_dir = os.path.dirname(os.path.abspath(__file__)) +ext_modules = [Extension( + 'rotorc', + sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], +)] + +setup( + name='rotor c extension', + version='0.1', + description='rotor c extension for faster dp computing', + ext_modules=ext_modules, +) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c new file mode 100644 index 000000000..0fdcfd58a --- /dev/null +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c @@ -0,0 +1,197 @@ +#define PY_SSIZE_T_CLEAN +#include <Python.h> + +long* PySequenceToLongArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + long* result = (long*)calloc(len + 1, sizeof(long)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = 
PySequence_GetItem(pylist, i); + result[i] = PyLong_AsLong(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +double* PySequenceToDoubleArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + double* result = (double*)calloc(len + 1, sizeof(double)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyFloat_AsDouble(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +long* getLongArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + long* result = PySequenceToLongArray(sequence); + Py_DECREF(sequence); + return result; +} + +double* getDoubleArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + double* result = PySequenceToDoubleArray(sequence); + Py_DECREF(sequence); + return result; +} + +static PyObject* computeTable(PyObject* self, PyObject* args) { + PyObject* chainParam; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL; + + double* ftime = getDoubleArray(chainParam, "ftime"); + if (!ftime) return NULL; + + double* btime = getDoubleArray(chainParam, "btime"); + if (!btime) return NULL; + + long* x = getLongArray(chainParam, "x"); + if (!x) return NULL; + + long* xbar = getLongArray(chainParam, "xbar"); + if (!xbar) return NULL; + + long* ftmp = getLongArray(chainParam, "btmp"); + if (!ftmp) return NULL; + + long* btmp = getLongArray(chainParam, "btmp"); + if (!btmp) return NULL; + + long chainLength = PyObject_Length(chainParam); + if (!chainLength) return NULL; + +#define COST_TABLE(m, i, l) \ + costTable[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + double* costTable = (double*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double)); + +#define 
BACK_PTR(m, i, l) \ + backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \ + (i) * (chainLength + 1) + (l)] + long* backPtr = (long*)calloc( + (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long)); + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chainLength; ++i) + if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) && + (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) + COST_TABLE(m, i, i) = ftime[i] + btime[i]; + else + COST_TABLE(m, i, i) = INFINITY; + + for (long m = 0; m <= mmax; ++m) + for (long d = 1; d <= chainLength; ++d) { + for (long i = 0; i <= chainLength - d; ++i) { + long idx = i + d; + long mmin = x[idx + 1] + x[i + 1] + ftmp[i]; + if (idx > i + 1) { + long maxCostFWD = 0; + for (long j = i + 1; j < idx; j++) { + maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]); + } + mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD); + } + if ((m >= mmin)) { + long bestLeaf = -1; + double sumFw = 0; + double bestLeafCost = INFINITY; + for (long j = i + 1; j <= idx; ++j) { + sumFw += ftime[j - 1]; + if (m >= x[j]) { + double cost = sumFw + COST_TABLE(m - x[j], j, idx) + + COST_TABLE(m, i, j - 1); + if (cost < bestLeafCost) { + bestLeafCost = cost; + bestLeaf = j; + } + } + } + double chainCost = INFINITY; + if (m >= xbar[i + 1]) + chainCost = + COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx); + if (bestLeafCost <= chainCost) { + COST_TABLE(m, i, idx) = bestLeafCost; + BACK_PTR(m, i, idx) = bestLeaf; + } else { + COST_TABLE(m, i, idx) = chainCost; + BACK_PTR(m, i, idx) = -1; + } + } else + COST_TABLE(m, i, idx) = INFINITY; + } + } + + free(ftime); + free(btime); + free(x); + free(xbar); + free(ftmp); + free(btmp); + + PyObject* pyCostTable = PyList_New(mmax + 1); + PyObject* pyBackPtr = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* pyCostTable_m = PyList_New(chainLength + 1); + PyList_SET_ITEM(pyCostTable, m, pyCostTable_m); + PyObject* pyBackPtr_m = 
PyList_New(chainLength + 1); + PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m); + for (long i = 0; i <= chainLength; ++i) { + PyObject* pyCostTable_m_i = PyDict_New(); + PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i); + PyObject* pyBackPtr_m_i = PyDict_New(); + PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i); + for (long l = i; l <= chainLength; ++l) { + PyObject* pyVar_l = PyLong_FromLong(l); + PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l)); + PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l); + Py_DECREF(pyCostTable_m_i_l); + PyObject* pyBackPtr_m_i_l; + if (BACK_PTR(m, i, l) < 0) + pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True); + else + pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l)); + PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l); + Py_DECREF(pyBackPtr_m_i_l); + Py_DECREF(pyVar_l); + } + } + } + + free(costTable); + free(backPtr); + + PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr); + Py_DECREF(pyCostTable); + Py_DECREF(pyBackPtr); + return result; +} + +static PyMethodDef rotorMethods[] = { + {"compute_table", computeTable, METH_VARARGS, + "Compute the optimal table with the rotor algorithm."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef rotorModule = { + PyModuleDef_HEAD_INIT, "rotorc", /* name of module */ + "A simple implementation of dynamic programming algorithm rotor with C in " + "https://hal.inria.fr/hal-02352969. Some code are adapted from " + "https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be + NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + rotorMethods}; + +PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); } diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index adfb25371..22dbc8be0 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from torch import Tensor from torch.fx import Graph, Node @@ -15,9 +15,9 @@ from colossalai.fx.profiler import ( from colossalai.logging import get_dist_logger from .ckpt_solver_base import CheckpointSolverBase -from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence +from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence -__all__ = ['CheckpointSolverBase'] +__all__ = ['CheckpointSolverRotor'] class CheckpointSolverRotor(CheckpointSolverBase): @@ -59,11 +59,12 @@ class CheckpointSolverRotor(CheckpointSolverBase): self.back_ptr = None self.sequence = None - def solve(self, force_python: bool = False) -> Graph: + def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: """Solve the checkpointing problem using rotor algorithm. Args: force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False. + verbose (bool, optional): Print verbose information. Defaults to False. Returns: graph (Graph): The optimized graph, should be a copy of the original graph. 
@@ -76,14 +77,22 @@ class CheckpointSolverRotor(CheckpointSolverBase): else: self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots) + if verbose: + self.print_chain() + # backtrack try: - self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr) + self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, + self.back_ptr) self._annotate_from_sequence(self.sequence, self.node_list) - except RuntimeError as e: + except ValueError as e: # using logger to annonce that the solver is failed logger = get_dist_logger() logger.warning(f'Checkpoint solver failed: {e}') + raise ValueError + + if verbose: + self.print_sequence() return deepcopy(self.graph) @@ -100,42 +109,42 @@ class CheckpointSolverRotor(CheckpointSolverBase): @classmethod def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: input_tensors = cls._extract_input(graph) - fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list() + ftime, btime, ftmp, btmp = list(), list(), list(), list() xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)] - for idx, node in enumerate(node_list): + for node in node_list: node_info = cls._extract_node_info(node) - fwd_time.append(node_info[0]) - bwd_time.append(node_info[1]) + ftime.append(node_info[0]) + btime.append(node_info[1]) x.append(node_info[2]) xbar.append(node_info[3]) ftmp.append(node_info[4]) btmp.append(node_info[5]) # currently we view loss backward temp as zero - bwd_time.append(0) + btime.append(0) btmp.append(0) - return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp) + return Chain(ftime, btime, x, xbar, ftmp, btmp) @classmethod def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: """Extract node info from a list of nodes""" xbar = 0 - fwd_time = 0 - bwd_time = 0 + ftime = 0 + btime = 0 for n in node: assert isinstance(n, Node), f'{n} is not a Node' xbar += calculate_fwd_tmp(n) 
+ calculate_fwd_out(n) # minimum flop count is required - fwd_time += max(calculate_fwd_time(n), 1.0) - bwd_time += max(calculate_bwd_time(n), 1.0) + ftime += max(calculate_fwd_time(n), 1.0) + btime += max(calculate_bwd_time(n), 1.0) x = calculate_fwd_out(node[-1]) xbar = max(x, xbar) ftmp = cls._extract_ftmp(node) btmp = cls._extract_btmp(node) - return fwd_time, bwd_time, x, xbar, ftmp, btmp + return ftime, btime, x, xbar, ftmp, btmp @staticmethod def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: @@ -180,17 +189,17 @@ class CheckpointSolverRotor(CheckpointSolverBase): return btmp @staticmethod - def _compute_table(chain: Chain, mem_slots: int) -> Tuple: + def _compute_table(chain: Chain, mmax: int) -> Tuple: """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer. Args: chain (Chain): A basic linearized structure for solving the dynamic programming problem. - mem_slots (int): Number of slots for discretizing memory budget. + mmax (int): Maximum number of memory slots. 
Returns: - cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length - and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax - back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice + cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length + and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax + back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint of length j """ @@ -203,13 +212,13 @@ class CheckpointSolverRotor(CheckpointSolverBase): btmp = chain.btmp + [0] # Build table - cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)] - back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)] + cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] + back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] # Last one is a dict because its indices go from i to l. 
Renumbering will wait for C implementation # Initialize borders of the tables for lmax-lmin = 0 - for m in range(mem_slots + 1): - for i in range(chain.length + 1): + for m in range(mmax + 1): + for i in range(len(chain) + 1): limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) if m >= limit: # Equation (1) cost_table[m][i][i] = ftime[i] + btime[i] @@ -217,9 +226,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): cost_table[m][i][i] = float("inf") # Compute everything - for m in range(mem_slots + 1): - for d in range(1, chain.length + 1): - for i in range(chain.length + 1 - d): + for m in range(mmax + 1): + for d in range(1, len(chain) + 1): + for i in range(len(chain) + 1 - d): idx = i + d mmin = x[idx + 1] + x[i + 1] + ftmp[i] if idx > i + 1: @@ -248,20 +257,46 @@ class CheckpointSolverRotor(CheckpointSolverBase): return cost_table, back_ptr @staticmethod - def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple: - raise NotImplementedError("C implementation not available yet") + def _compute_table_c(chain: Chain, mmax: int) -> Tuple: + try: + from .rotorc import compute_table - def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]], - back_ptr: List[List[Dict[int, int]]]) -> List[int]: + # build module if module not found + except ModuleNotFoundError: + import os + import subprocess + import sys + logger = get_dist_logger() + logger.info("rotorc hasn't been built! Building library...", ranks=[0]) + this_dir = os.path.dirname(os.path.abspath(__file__)) + result = subprocess.Popen( + [ + f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", + f"--build-lib={this_dir}" + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if result.wait() == 0: + logger.info("rotorc has been built!", ranks=[0]) + from .rotorc import compute_table + else: + logger.warning("rotorc built failed! 
Using python version!", ranks=[0]) + return CheckpointSolverRotor._compute_table(chain, mmax) + return compute_table(chain, mmax) + + @staticmethod + def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], + back_ptr: List[Any]) -> "Sequence": """Backtrack the cost table and retrieve the optimal checkpointing strategy. Args: chain (Chain): A basic linearized structure for solving the dynamic programming problem. - lmin (int): The left index of the interval to backtrack. - lmax (int): The right index of the interval to backtrack. - mem_budget (int): The memory budget for processing this interval. - cost_table (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions - back_ptr (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions + lhs (int): The left index of the interval to backtrack. + rhs (int): The right index of the interval to backtrack. + budget (int): The memory budget for processing this interval. + cost_table (List[Any]): See `._compute_table()` for definitions + back_ptr (List[Any]): See `._compute_table()` for definitions Raises: ValueError: Can not process the chain. @@ -269,36 +304,45 @@ class CheckpointSolverRotor(CheckpointSolverBase): Returns: sequence (Sequence): The sequence of executing nodes with checkpoints. 
""" - if mem_budget <= 0: - raise ValueError(f"Can not process a chain with negative memory {mem_budget}") - elif cost_table[mem_budget][lmin][lmax] == float("inf"): - raise ValueError(f"Can not process this chain from index {lmin} to {lmax} with memory {mem_budget}") + if budget <= 0: + raise ValueError(f"Can not process a chain with negative memory {budget}") + elif cost_table[budget][lhs][rhs] == float("inf"): + raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}") - sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget)) - if lmin == lmax: - if lmin == chain.length: - sequence.insert(Loss()) + sequence = Sequence() + if rhs == lhs: + if lhs == len(chain): + sequence += [Loss()] else: - sequence.insert(ForwardEnable(lmin)) - sequence.insert(Backward(lmin)) + sequence += [ForwardEnable(lhs), Backward(lhs)] return sequence - if back_ptr[mem_budget][lmin][lmax][0]: - sequence.insert(ForwardEnable(lmin)) - sequence.insert_sequence( - self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr)) - sequence.insert(Backward(lmin)) + if back_ptr[budget][lhs][rhs][0]: + sequence += [ + ForwardEnable(lhs), + CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, + back_ptr), + Backward(lhs), + ] else: - j = back_ptr[mem_budget][lmin][lmax][1] - sequence.insert(ForwardCheck(lmin)) - for k in range(lmin + 1, j): - sequence.insert(ForwardNograd(k)) - sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr)) - sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr)) + best_leaf = back_ptr[budget][lhs][rhs][1] + sequence += [ForwardCheck(lhs)] + sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] + sequence += [ + CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, + back_ptr), + 
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), + ] return sequence @staticmethod def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): + """Annotate the nodes in the node_list with activation checkpoint from the sequence. + + Args: + sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations. + node_list (List[List[Node]]): The list of nodes to annotate. + """ op_list = sequence.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) fwd_list = op_list[:op_list.index(loss_op)] diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py index cc7172fbc..ab0c6c5ad 100644 --- a/colossalai/auto_parallel/checkpoint/operation.py +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -1,6 +1,6 @@ import math from abc import ABC -from typing import List +from typing import Any, Iterable, List from torch.utils._pytree import tree_map @@ -33,23 +33,25 @@ class Chain: self.xbar = xbar self.ftmp = ftmp self.btmp = btmp - self.length = len(ftime) if check_consistency and not self.check_lengths(): raise AttributeError("In Chain, input lists do not have consistent lengths") def check_lengths(self): - return ((len(self.ftime) == self.length) and (len(self.btime) == self.length + 1) - and (len(self.x) == self.length + 1) and (len(self.ftmp) == self.length) - and (len(self.btmp) == self.length + 1) and (len(self.xbar) == self.length + 1)) + return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1)) def __repr__(self): chain_list = [] - for i in range(self.length): + for i in range(len(self)): chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])) - i = self.length + i = len(self) 
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i])) return chain_list.__repr__() + def __len__(self): + return len(self.ftime) + def discretize_all(self, unit: int): """Discretize the chain into a list of chains according to unit size.""" discretizer = lambda val: math.ceil(val / unit) @@ -163,79 +165,20 @@ class DiscardMemory(MemoryAccess): name = "DM" -class Function: +class Sequence(list): - def __init__(self, name, *args): - self.name = name - self.args = args - self.str_args = ','.join(str(v) for v in self.args) - - def __repr__(self): - return "{n}({args})".format(n=self.name, args=self.str_args) - - -class Sequence: - - def __init__(self, function): - self.sequence = [] #List of Operation and Sequence - self.function = function #Description the function (name and parameters) + def __init__(self): + super().__init__() def __repr__(self): return repr(self.list_operations()) def list_operations(self): op_list = [] - for x in self.sequence: + for x in self: if isinstance(x, Operation): op_list.append(x) else: assert isinstance(x, Sequence) op_list += x.list_operations() return op_list - - def insert(self, operation): - self.sequence.append(operation) - - def remove(self, operation_index): - del self.sequence[operation_index] - - def insert_sequence(self, sequence): - self.sequence.append(sequence) - - def shift(self, value): - for x in self.sequence: - x.shift(value) - return self - - def remove_useless_write(self): - if self.sequence: - if isinstance(self.sequence[0], WriteMemory): - self.remove(0) - return self - - def get_makespan(self, chain): - return sum(op.cost(chain) for op in self.list_operations()) - - def without_suffix(self): - ops = self.list_operations() - end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] - try: - last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable) - except ValueError: - last_idx = -1 - if last_idx == end_of_first_phase - 1: - return 
(self, None) - chain_length = ops[end_of_first_phase - - 1].index ## Some assumption here about the sequence (finishes with Forward_L - start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice - result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain)) - for i in range(last_idx + 1): - result.insert(ops[i]) - result.insert(Loss()) - for i in range(chain_length, start_of_fwd_enable_chain - 1, -1): - position = end_of_first_phase + 1 + (chain_length - i) - assert type(ops[position]) is Backward - assert ops[position].index == i - for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)): - result.insert(ops[i]) - return (result, start_of_fwd_enable_chain) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index dededa410..c87cd4321 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: kwargs['inplace'] = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 do_not_cache = False meta.bwd_mem_out -= param_size @@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: module.inplace = True + meta.bwd_mem_tmp = 0 + meta.bwd_mem_out = 0 do_not_cache = False # grad for param will not be counted