"""Build script for the rotor dynamic-programming C extension.

Compiles ``dynamic_programs.c`` (expected next to this file) into the
``dynamic_programs_C_version`` extension module.
"""
import os

from setuptools import Extension, setup

# Absolute directory of this script, so the C source is located
# independently of the current working directory.
this_dir = os.path.dirname(os.path.abspath(__file__))

c_source = os.path.join(this_dir, 'dynamic_programs.c')
rotor_extension = Extension(name='dynamic_programs_C_version', sources=[c_source])
ext_modules = [rotor_extension]

setup(name='rotor c extension',
      version='0.1',
      description='rotor c extension for faster dp computing',
      ext_modules=ext_modules)
def solver_rotor(gm: ColoGraphModule,
                 data,
                 mem_limit: int,
                 mem_slots: int = 500,
                 cnode: List[str] = None,
                 eps: float = 0.0,
                 force_python: bool = False) -> ColoGraphModule:
    """solver that automatically find activation checkpoint in rotor's manner

    The dynamic-programming table is computed by the C extension
    ``dynamic_programs_C_version`` when it can be imported (it is built on
    demand if missing); otherwise the python implementation is used.

    Args:
        gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
        data (torch.Tensor): input data.
        mem_limit (int): memory budget in Byte.
        mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
        cnode (List[Node], optional): common node list for linearize. Defaults to None.
        eps (float): epsilon for memory decay. Defaults to 0.0
        force_python (bool): force to use python version of dynamic programs

    Returns:
        ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute

    Raises:
        RuntimeError: if any node of ``gm.graph`` has an empty ``meta`` dict,
            i.e. MetaInfoProp has not been run on ``gm``.
    """
    logger = get_dist_logger()

    # ``None`` means "fall back to the python implementation".
    c_compute_table = None
    if not force_python:
        try:
            from .dynamic_programs_C_version import persistent_compute_table as c_compute_table
        except ModuleNotFoundError:
            # build module if module not found
            import importlib
            import os
            import subprocess
            import sys

            logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
            this_dir = os.path.dirname(os.path.abspath(__file__))
            # NOTE(review): every process that reaches this point starts its own
            # build; only the log message is restricted to rank 0.
            #
            # An argument list (no shell) keeps paths with spaces intact, and
            # sys.executable builds for the interpreter that will import the
            # module. subprocess.run drains the pipes, so a verbose compiler
            # cannot block on a full pipe the way Popen(...).wait() can.
            build = subprocess.run(
                [sys.executable,
                 os.path.join(this_dir, "build_c_ext.py"),
                 "build_ext",
                 f"--build-lib={this_dir}"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
            if build.returncode == 0:
                logger.info("dynamic_programs_C_version has been built!", ranks=[0])
                # The extension file was created after interpreter start-up;
                # make sure the import system's directory caches notice it.
                importlib.invalidate_caches()
                from .dynamic_programs_C_version import persistent_compute_table as c_compute_table
            else:
                logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])

    # check if metainfoprop is done
    if any(len(node.meta) == 0 for node in gm.graph.nodes):
        raise RuntimeError(
            "Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")

    # linearize the graph
    node_list = linearize(gm, cnode)

    # construct chain
    mem_unit = mem_limit * (1.0 - eps) // mem_slots
    chain: Chain = _construct_chain(node_list, data)
    chain._discretize(mem_unit)

    # use C version if possible
    if c_compute_table is not None:
        logger.info("Using C version rotor solver!", ranks=[0])
        opt_table = c_compute_table(chain, mem_slots)
    else:
        opt_table = _compute_table(chain, mem_slots)
        logger.info("Using python version rotor solver!", ranks=[0])

    # found sequence
    sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
    _annotate_from_sequence(sequence, node_list)

    # set __sequence__ attribute to GraphModule
    setattr(gm, "__sequence__", sequence)
    gm.recompile()
    return gm
Please run MetaInfoProp before calling solver!") + + # linearize the graph + node_list = linearize(gm, cnode) + + # construct chain + mem_unit = mem_limit * (1.0 - eps) // mem_slots chain: Chain = _construct_chain(node_list, data) chain._discretize(mem_unit) - opt_table = _compute_table(chain, mem_slots) + + # use C version if possible + if CVERSION and not force_python: + logger.info("Using C version rotor solver!", ranks=[0]) + opt_table = persistent_compute_table(chain, mem_slots) + else: + opt_table = _compute_table(chain, mem_slots) + logger.info("Using python version rotor solver!", ranks=[0]) + + # found sequence sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) _annotate_from_sequence(sequence, node_list) # set __sequence__ attribute to GraphModule setattr(gm, "__sequence__", sequence) + gm.recompile() return gm diff --git a/colossalai/fx/passes/algorithms/dynamic_programs.c b/colossalai/fx/passes/algorithms/dynamic_programs.c new file mode 100644 index 000000000..4bea393a7 --- /dev/null +++ b/colossalai/fx/passes/algorithms/dynamic_programs.c @@ -0,0 +1,513 @@ +#define PY_SSIZE_T_CLEAN +#include + +long* PySequenceToLongArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + long* result = (long*)calloc(len + 1, sizeof(long)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyLong_AsLong(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +double* PySequenceToDoubleArray(PyObject* pylist) { + if (!(pylist && PySequence_Check(pylist))) return NULL; + Py_ssize_t len = PySequence_Size(pylist); + double* result = (double*)calloc(len + 1, sizeof(double)); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* item = PySequence_GetItem(pylist, i); + result[i] = PyFloat_AsDouble(item); + Py_DECREF(item); + } + result[len] = 0; + return result; +} + +long* getLongArray(PyObject* 
container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + long* result = PySequenceToLongArray(sequence); + Py_DECREF(sequence); + return result; +} + +double* getDoubleArray(PyObject* container, const char* attributeName) { + PyObject* sequence = PyObject_GetAttrString(container, attributeName); + double* result = PySequenceToDoubleArray(sequence); + Py_DECREF(sequence); + return result; +} + +static PyObject* persistent_compute_table(PyObject* self, PyObject* args) { + PyObject* chain_param; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; + + double* fw = getDoubleArray(chain_param, "fweight"); + if (!fw) return NULL; + + double* bw = getDoubleArray(chain_param, "bweight"); + if (!bw) return NULL; + + long* cw = getLongArray(chain_param, "cweight"); + if (!cw) return NULL; + + long* cbw = getLongArray(chain_param, "cbweight"); + if (!cbw) return NULL; + + long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp"); + if (!cbw) return NULL; + + long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp"); + if (!cbw) return NULL; + + PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); + if (!chain_length_param) return NULL; + long chain_length = PyLong_AsLong(chain_length_param); + Py_DECREF(chain_length_param); + + // TODO: Can be optimized by only allocating memory for l >= i + // TODO: float / int instead of double / long ? 
+#define OPT(m, i, l) \ + opt[(m) * (chain_length + 1) * (chain_length + 1) + \ + (i) * (chain_length + 1) + (l)] + double* opt = (double*)calloc( + (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double)); + +#define WHAT(m, i, l) \ + what[(m) * (chain_length + 1) * (chain_length + 1) + \ + (i) * (chain_length + 1) + (l)] + long* what = (long*)calloc( + (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long)); + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) + // TODO: Can be optimized to remove the IF by reordering loops + if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) && + (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i])) + OPT(m, i, i) = fw[i] + bw[i]; + else + OPT(m, i, i) = INFINITY; + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) { + long maxCostFWD = 0; + for (long l = i + 1; l <= chain_length; ++l) { + long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i]; + if (l > i + 1) { + maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]); + mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD); + } + if ((m >= mmin)) { + long bestLeaf = -1; + double sumFw = 0; + double bestLeafCost = INFINITY; + /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j = + /// i+1 + for (long j = i + 1; j <= l; ++j) { + sumFw += fw[j - 1]; + if (m >= cw[j]) { + double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1); + if (cost < bestLeafCost) { + bestLeafCost = cost; + bestLeaf = j; + } + } + } + double chainCost = INFINITY; + if (m >= cbw[i + 1]) + chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l); + if (bestLeafCost <= chainCost) { + OPT(m, i, l) = bestLeafCost; + WHAT(m, i, l) = bestLeaf; + } else { + OPT(m, i, l) = chainCost; + WHAT(m, i, l) = -1; + } + } else + OPT(m, i, l) = INFINITY; + } + } + + free(fw); + free(bw); + free(cw); + free(cbw); + free(fwd_tmp); + free(bwd_tmp); + + PyObject* res_opt = PyList_New(mmax + 1); + PyObject* res_what = PyList_New(mmax + 1); + + // 
Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* res_opt_m = PyList_New(chain_length + 1); + PyList_SET_ITEM(res_opt, m, res_opt_m); + PyObject* res_what_m = PyList_New(chain_length + 1); + PyList_SET_ITEM(res_what, m, res_what_m); + for (long i = 0; i <= chain_length; ++i) { + PyObject* res_opt_m_i = PyDict_New(); + PyList_SET_ITEM(res_opt_m, i, res_opt_m_i); + PyObject* res_what_m_i = PyDict_New(); + PyList_SET_ITEM(res_what_m, i, res_what_m_i); + for (long l = i; l <= chain_length; ++l) { + PyObject* res_l = PyLong_FromLong(l); + PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l)); + PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l); + Py_DECREF(res_opt_m_i_l); + PyObject* res_what_m_i_l; + long what_m_i_l = WHAT(m, i, l); + if (what_m_i_l < 0) + res_what_m_i_l = Py_BuildValue("(O)", Py_True); + else + res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l); + PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l); + Py_DECREF(res_what_m_i_l); + Py_DECREF(res_l); + } + } + } + + free(opt); + free(what); + + PyObject* result = PyTuple_Pack(2, res_opt, res_what); + Py_DECREF(res_opt); + Py_DECREF(res_what); + return result; +} + +// long i = L - s, j = t - s, k = l - t +inline long floating_index_in_array(long m_factor, long m, long i, long j, + long k) { + return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j - + (j * (j - 1)) / 2 + k; +} + +typedef struct { + long sp; + long r; + long tp; +} index_t; + +static PyObject* floating_compute_table(PyObject* self, PyObject* args) { + PyObject* chain_param; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; + + double* fw = getDoubleArray(chain_param, "fweigth"); + if (!fw) return NULL; + + double* bw = getDoubleArray(chain_param, "bweigth"); + if (!bw) return NULL; + + long* cw = getLongArray(chain_param, "cweigth"); + if (!cw) return NULL; + + long* cbw = getLongArray(chain_param, "cbweigth"); + if (!cbw) return NULL; + + 
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp"); + if (!fwd_tmp) return NULL; + + long* bwd_tmp = getLongArray(chain_param, "bwd_tmp"); + if (!bwd_tmp) return NULL; + + PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); + if (!chain_length_param) return NULL; + long chain_length = PyLong_AsLong(chain_length_param); + Py_DECREF(chain_length_param); + + const long m_factor = + (chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12; + + // Defined for 0 <= s <= t <= l <= chain_length, for all m +#undef OPT +#define OPT(m, s, t, l) \ + opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \ + (l) - (t))] + double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double)); + +#undef WHAT +#define WHAT(m, s, t, l) \ + what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \ + (l) - (t))] + index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t)); + + double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double)); + double total = 0; + for (long i = 0; i < chain_length; ++i) { + partialSumsFW[i] = total; + total += fw[i]; + } + partialSumsFW[chain_length] = total; + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) { + // TODO: Can be optimized to remove the IF by reordering loops + if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) && + (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i])) + OPT(m, i, i, i) = fw[i] + bw[i]; + else + OPT(m, i, i, i) = INFINITY; + } + + for (long m = 0; m <= mmax; ++m) + for (long d = 1; d <= chain_length; ++d) { // d = l - s + for (long s = 0; s <= chain_length - d; ++s) { + long l = s + d; + long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s]; + long memNullSecond = 0; + for (long j = s + 1; j < l; ++j) { + long val = cw[j] + cw[j + 1] + fwd_tmp[j]; + if (val > memNullSecond) memNullSecond = val; + } + for (long t = s; t <= l; ++t) { + double chainCost = INFINITY; + if ((s == t) && (m >= cw[l + 1] 
+ cbw[s + 1] + fwd_tmp[s]) && + (m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) { + chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l); + } + double bestLeafCost = INFINITY; + index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1}; + if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) { + for (long r = s; r <= t; ++r) + if (cw[s] <= cw[r]) + for (long tp = t + 1; tp <= l; ++tp) + for (long sp = r + 1; sp <= tp; ++sp) { + long mp = m - cw[r] + cw[s]; + assert(mp >= 0); + if (mp >= cw[sp]) { + double value = partialSumsFW[sp] - partialSumsFW[s] + + OPT(mp - cw[sp], sp, tp, l) + + OPT(mp, r, t, tp - 1); + if (value < bestLeafCost) { + bestLeafCost = value; + bestLeaf.sp = sp; + bestLeaf.r = r; + bestLeaf.tp = tp; + } + } + } + } + if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) { + OPT(m, s, t, l) = bestLeafCost; + WHAT(m, s, t, l).sp = bestLeaf.sp; + WHAT(m, s, t, l).r = bestLeaf.r; + WHAT(m, s, t, l).tp = bestLeaf.tp; + } else { + OPT(m, s, t, l) = chainCost; + WHAT(m, s, t, l).sp = -1; + } + } + } + } + + free(fw); + free(bw); + free(cw); + free(cbw); + free(fwd_tmp); + free(bwd_tmp); + + PyObject* res_opt = PyList_New(mmax + 1); + PyObject* res_what = PyList_New(mmax + 1); + + // Convert the result into Python world + PyObject* true_tuple = Py_BuildValue("(O)", Py_True); + for (long m = 0; m <= mmax; ++m) { + PyObject* res_opt_m = PyDict_New(); + PyList_SET_ITEM(res_opt, m, res_opt_m); + PyObject* res_what_m = PyDict_New(); + PyList_SET_ITEM(res_what, m, res_what_m); + for (long s = 0; s <= chain_length; ++s) + for (long t = s; t <= chain_length; ++t) + for (long l = t; l <= chain_length; ++l) { + PyObject* key = Py_BuildValue("(lll)", s, t, l); + PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l)); + PyDict_SetItem(res_opt_m, key, value_opt); + PyObject* value_what = true_tuple; + index_t* idx_what = &WHAT(m, s, t, l); + if (idx_what->sp >= 0) + value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp, + idx_what->r, 
idx_what->tp); + PyDict_SetItem(res_what_m, key, value_what); + if (value_what != true_tuple) Py_DECREF(value_what); + Py_DECREF(key); + Py_DECREF(value_opt); + } + } + + Py_DECREF(true_tuple); + + free(opt); + free(what); + + PyObject* result = PyTuple_Pack(2, res_opt, res_what); + Py_DECREF(res_opt); + Py_DECREF(res_what); + return result; +} + +static PyObject* griewank_heterogeneous_compute_table(PyObject* self, + PyObject* args) { + PyObject* chain_param; + int mmax; + + if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL; + + double* fw = getDoubleArray(chain_param, "fweigth"); + if (!fw) return NULL; + + double* bw = getDoubleArray(chain_param, "bweigth"); + if (!bw) return NULL; + + long* cw = getLongArray(chain_param, "cweigth"); + if (!cw) return NULL; + + long* cbw = getLongArray(chain_param, "cbweigth"); + if (!cbw) return NULL; + + PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length"); + if (!chain_length_param) return NULL; + long chain_length = PyLong_AsLong(chain_length_param); + Py_DECREF(chain_length_param); + + // TODO: Can be optimized by only allocating memory for l >= i + // TODO: float / int instead of double / long ? 
+#undef OPT +#define OPT(m, i, l) \ + opt[(m) * (chain_length + 1) * (chain_length + 1) + \ + (i) * (chain_length + 1) + (l)] + double* opt = (double*)calloc( + (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double)); + + // Compute partial sums + double* sumfw = (double*)calloc(chain_length, sizeof(double)); + double* sumbw = (double*)calloc(chain_length + 1, sizeof(double)); + double* sumsumfw = (double*)calloc(chain_length, sizeof(double)); + + double total = 0; + for (long i = 0; i < chain_length; ++i) { + total += fw[i]; + sumfw[i] = total; + } + + total = 0; + for (long i = 0; i < chain_length + 1; ++i) { + total += bw[i]; + sumbw[i] = total; + } + + total = 0; + for (long i = 0; i < chain_length; ++i) { + total += sumfw[i]; + sumsumfw[i] = total; + } + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i <= chain_length; ++i) { + // TODO: Can be optimized to remove the IF by reordering loops + if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1])) + OPT(m, i, i) = bw[i]; + else + OPT(m, i, i) = INFINITY; + + if (i < chain_length) { + long maxC = fmaxl(cw[i], cw[i + 1]); + long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC); + if ((m >= cbw[i]) && (m >= cw[i] + maxCB)) + OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1]; + else + OPT(m, i, i + 1) = INFINITY; + } + } + + for (long m = 0; m <= mmax; ++m) + for (long i = 0; i + 2 <= chain_length; ++i) { + long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]); + long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]); + long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il; + for (long l = i + 2; l <= chain_length; ++l) { + maxCW_il = fmax(maxCW_il, cw[l + 1]); + maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il); + long mmin = fmaxl(mminCst, maxCostFWD); + if ((m >= mmin)) { + double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0); + noCheckpointCost += + sumsumfw[l - 1] - + (i > 0 ? 
sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0); + + double valueCost = INFINITY; + if (m >= cw[i]) { + double sumFwds = 0; + for (long j = i + 1; j < l; ++j) { + sumFwds += fw[j - 1]; + valueCost = fmin( + valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1)); + } + } + OPT(m, i, l) = fmin(noCheckpointCost, valueCost); + } else + OPT(m, i, l) = INFINITY; + } + } + + free(sumfw); + free(sumbw); + free(sumsumfw); + free(fw); + free(bw); + free(cw); + free(cbw); + + PyObject* res_opt = PyList_New(mmax + 1); + + // Convert the result into Python world + for (long m = 0; m <= mmax; ++m) { + PyObject* res_opt_m = PyList_New(chain_length + 1); + PyList_SET_ITEM(res_opt, m, res_opt_m); + for (long i = 0; i <= chain_length; ++i) { + PyObject* res_opt_m_i = PyDict_New(); + PyList_SET_ITEM(res_opt_m, i, res_opt_m_i); + for (long l = i; l <= chain_length; ++l) { + PyObject* res_l = PyLong_FromLong(l - i); + PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l)); + PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l); + Py_DECREF(res_opt_m_i_l); + Py_DECREF(res_l); + } + } + } + + free(opt); + + return res_opt; +} + +static PyMethodDef dynamic_programs_methods[] = { + {"persistent_compute_table", persistent_compute_table, METH_VARARGS, + "Compute the optimal table with the persistent algorithm."}, + {"floating_compute_table", floating_compute_table, METH_VARARGS, + "Compute the optimal table with the floating algorithm."}, + {"griewank_heterogeneous_compute_table", + griewank_heterogeneous_compute_table, METH_VARARGS, + "Compute the optimal table for the Griewank Heterogeneous Model."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef dynamic_programs_module = { + PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + dynamic_programs_methods}; + +PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) { + return PyModule_Create(&dynamic_programs_module); +} diff --git a/setup.py b/setup.py index b58a4d989..8341a97b7 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import os import subprocess import re -from setuptools import find_packages, setup +from setuptools import find_packages, setup, Extension # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -100,7 +100,7 @@ def get_version(): version += f'+torch{torch_version}cu{cuda_version}' return version - + if build_cuda_ext: try: import torch @@ -115,7 +115,7 @@ if build_cuda_ext: except ImportError: print('torch is not found. CUDA extension will not be installed') build_cuda_ext = False - + if build_cuda_ext: build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME) diff --git a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py new file mode 100644 index 000000000..c638e7ac2 --- /dev/null +++ b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py @@ -0,0 +1,65 @@ +import copy +import torch +import torch.multiprocessing as mp +import torchvision.models as tm +import torch.fx +import colossalai +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx.passes.algorithms import solver_rotor +from colossalai.fx.passes.algorithms.operation import Sequence +from colossalai.core import global_context as gpc +from colossalai.utils import free_port +import pytest +from colossalai import META_COMPATIBILITY +if META_COMPATIBILITY: + from colossalai.fx.profiler.tensor import MetaTensor + +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + withcodegen = True +except: + from colossalai.fx.codegen import python_code_with_activation_checkpoint + 
def _run_C_solver_consistency_test(rank=0):
    """Check that the python and the C rotor solvers yield the same schedule.

    For each model the graph is traced on meta tensors, annotated by
    MetaInfoProp, and solved twice with ``solver_rotor`` (once with
    ``force_python=True``, once without); the two operation lists must have
    the same length and element-wise equal ``repr``.

    Raises:
        RuntimeError: if ``ActivationCheckpointCodeGen`` or ``MetaTensor`` is
            unavailable, since the body below cannot run without them.
    """
    # Both names are imported conditionally at module level; fail with a clear
    # message instead of a NameError half-way through the run.
    if not (withcodegen and META_COMPATIBILITY):
        raise RuntimeError("C solver consistency test requires ActivationCheckpointCodeGen and MetaTensor")

    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    try:
        for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
            model = M()
            data = torch.rand(128, 3, 224, 224, device='meta')

            tracer = ColoTracer()
            graph = tracer.trace(model, meta_args={"x": data})
            graph.set_codegen(ActivationCheckpointCodeGen())
            gm = ColoGraphModule(model, graph, model.__class__.__name__)
            data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
            MetaInfoProp(gm).run(data_meta)

            # python solver
            gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
            sequence_python: Sequence = copy.deepcopy(gm.__sequence__)

            # C solver
            gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
            sequence_C: Sequence = copy.deepcopy(gm.__sequence__)

            python_ops = sequence_python.list_operations()
            c_ops = sequence_C.list_operations()

            # make sure the solutions are the same
            assert len(python_ops) == len(c_ops) and \
                all(repr(python_op) == repr(c_op) for (python_op, c_op) in zip(python_ops, c_ops))
    finally:
        # Release the global context even when an assertion above fails.
        gpc.destroy()


@pytest.mark.skipif(not META_COMPATIBILITY, reason="MetaTensor is not available")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
def test_C_solver_consistency():
    """Run the consistency check in a single spawned process."""
    mp.spawn(_run_C_solver_consistency_test, nprocs=1)


if __name__ == '__main__':
    _run_C_solver_consistency_test(rank=0)