[autoparallel] add rotor C version (#1658)

* [autoparallel] add rotor c version

* [fx] remove metainfoprop in rotor solver

* [autoparallel] modify C code format

* [autoparallel] remove build.py

* [autoparallel] fix C extension build

* [autoparallel] add C solver consistency test

* [autoparallel] remove some unused imports

* [autoparallel] refactor rotor solver code

* [autoparallel] replace print with colossalai logger

* [autoparallel] ranks fixed
pull/1678/head
Boyuan Yao 2022-10-03 17:13:30 +08:00 committed by GitHub
parent 11ec070e53
commit 1df98d5b66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 649 additions and 14 deletions

View File

@ -0,0 +1,15 @@
from setuptools import setup, Extension
import os
# Absolute directory containing this build script; the C source is looked up
# next to it so the build works regardless of the current working directory.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Path of the single C translation unit compiled into the extension module.
_c_source = os.path.join(this_dir, 'dynamic_programs.c')

ext_modules = [
    Extension('dynamic_programs_C_version', sources=[_c_source]),
]

# Package metadata handed to setuptools together with the extension list.
_setup_kwargs = dict(
    name='rotor c extension',
    version='0.1',
    description='rotor c extension for faster dp computing',
    ext_modules=ext_modules,
)

setup(**_setup_kwargs)

View File

@ -5,9 +5,8 @@ from colossalai.fx.profiler import activation_size, parameter_size
import math import math
from .linearize import linearize from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai import META_COMPATIBILITY from colossalai.logging import get_dist_logger
# this is the python compute table code from rotor # this is the python compute table code from rotor
@ -323,34 +322,77 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit: int, mem_limit: int,
mem_slots: int = 500, mem_slots: int = 500,
cnode: List[str] = None, cnode: List[str] = None,
eps: float = 0.0) -> ColoGraphModule: eps: float = 0.0,
force_python: bool = False) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner """solver that automatically find activation checkpoint in rotor's manner
Args: Args:
gm (ColoGraphModule): ColoGraphModule generated by tracing model. gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
data (torch.Tensor): input data. data (torch.Tensor): input data.
mem_limit (int): memory budget in Byte. mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500. mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None. cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.0 eps (float): epsilon for memory decay. Defaults to 0.0
force_python (bool): force to use python version of dynamic programs
Returns: Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
""" """
node_list = linearize(gm, cnode) # try to import C version solver if force_python is not set
mem_unit = mem_limit * (1.0 - eps) // mem_slots logger = get_dist_logger()
if META_COMPATIBILITY: if not force_python:
from colossalai.fx.profiler import MetaTensor try:
data = MetaTensor(data, fake_device=next(gm.parameters()).device) from .dynamic_programs_C_version import persistent_compute_table
MetaInfoProp(gm).run(data) CVERSION = True
# build module if module not found
except ModuleNotFoundError:
import subprocess
import os
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
f'python {os.path.join(this_dir, "build_c_ext.py")} build_ext --build-lib={this_dir}',
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
if result.wait() == 0:
logger.info("dynamic_programs_C_version has been built!", ranks=[0])
from .dynamic_programs_C_version import persistent_compute_table
CVERSION = True
else:
logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
CVERSION = False
else:
CVERSION = False
# check if metainfoprop is done
if any(len(node.meta) == 0 for node in gm.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")
# linearize the graph
node_list = linearize(gm, cnode)
# construct chain
mem_unit = mem_limit * (1.0 - eps) // mem_slots
chain: Chain = _construct_chain(node_list, data) chain: Chain = _construct_chain(node_list, data)
chain._discretize(mem_unit) chain._discretize(mem_unit)
# use C version if possible
if CVERSION and not force_python:
logger.info("Using C version rotor solver!", ranks=[0])
opt_table = persistent_compute_table(chain, mem_slots)
else:
opt_table = _compute_table(chain, mem_slots) opt_table = _compute_table(chain, mem_slots)
logger.info("Using python version rotor solver!", ranks=[0])
# found sequence
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_list) _annotate_from_sequence(sequence, node_list)
# set __sequence__ attribute to GraphModule # set __sequence__ attribute to GraphModule
setattr(gm, "__sequence__", sequence) setattr(gm, "__sequence__", sequence)
gm.recompile()
return gm return gm

View File

@ -0,0 +1,513 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
/*
 * Copy a Python sequence of integers into a newly allocated C array.
 *
 * The buffer holds len + 1 elements: the converted items followed by one
 * trailing 0, so callers may read index `len` without going out of bounds.
 *
 * Returns NULL with a Python exception set when `pylist` is NULL (the
 * exception raised by the caller's failed lookup is kept), is not a
 * sequence, contains an item that cannot be converted to a C long, or when
 * memory cannot be allocated. On success the caller owns the buffer and
 * must release it with free().
 */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!(pylist && PySequence_Check(pylist))) {
    /* Never return NULL without an exception: CPython would otherwise
     * raise an unrelated SystemError in the caller. */
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_TypeError, "expected a sequence of integers");
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception already set by PySequence_Size */

  long* result = (long*)calloc((size_t)len + 1, sizeof(long));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    result[i] = PyLong_AsLong(item);
    Py_DECREF(item);
    /* -1 is a legal value, so the error indicator disambiguates it. */
    if (result[i] == -1 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
  }

  result[len] = 0;
  return result;
}
/*
 * Copy a Python sequence of numbers into a newly allocated C array of
 * doubles.
 *
 * The buffer holds len + 1 elements: the converted items followed by one
 * trailing 0, so callers may read index `len` without going out of bounds.
 *
 * Returns NULL with a Python exception set when `pylist` is NULL (the
 * exception raised by the caller's failed lookup is kept), is not a
 * sequence, contains an item that cannot be converted to a C double, or
 * when memory cannot be allocated. On success the caller owns the buffer
 * and must release it with free().
 */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!(pylist && PySequence_Check(pylist))) {
    /* Never return NULL without an exception: CPython would otherwise
     * raise an unrelated SystemError in the caller. */
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_TypeError, "expected a sequence of numbers");
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception already set by PySequence_Size */

  double* result = (double*)calloc((size_t)len + 1, sizeof(double));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    result[i] = PyFloat_AsDouble(item);
    Py_DECREF(item);
    /* -1.0 is a legal value, so the error indicator disambiguates it. */
    if (result[i] == -1.0 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
  }

  result[len] = 0;
  return result;
}
/*
 * Read attribute `attributeName` from `container` and convert it to a C
 * array of longs via PySequenceToLongArray.
 *
 * Returns NULL when the attribute lookup fails (the exception raised by
 * the lookup, typically AttributeError, is left set) or when the
 * conversion fails. On success the caller owns the buffer and must
 * release it with free().
 */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A failed lookup returns NULL; Py_DECREF on NULL would crash. */
  if (!sequence) return NULL;
  long* result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * Read attribute `attributeName` from `container` and convert it to a C
 * array of doubles via PySequenceToDoubleArray.
 *
 * Returns NULL when the attribute lookup fails (the exception raised by
 * the lookup, typically AttributeError, is left set) or when the
 * conversion fails. On success the caller owns the buffer and must
 * release it with free().
 */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A failed lookup returns NULL; Py_DECREF on NULL would crash. */
  if (!sequence) return NULL;
  double* result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * persistent_compute_table(chain, mmax) -> (opt, what)
 *
 * Python-callable entry point. `chain` must expose the sequence attributes
 * fweight, bweight, cweight, cbweight, fwd_mem_tmp, bwd_mem_tmp and the
 * integer attribute length; `mmax` is the largest memory index m tabulated.
 *
 * Returns a 2-tuple of nested containers indexed as table[m][i][l] for
 * 0 <= m <= mmax and 0 <= i <= l <= length (list of lists of dicts keyed
 * by l):
 *   opt[m][i][l]  -> float value of the DP table (inf when infeasible)
 *   what[m][i][l] -> (True,) when the stored choice is negative, otherwise
 *                    (False, j) with j the stored split index
 *
 * Returns NULL with a Python exception set on any failure. All temporary
 * buffers are released on every exit path.
 */
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
  PyObject* chain_param;
  int mmax;
  if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;

  /* Everything released at `cleanup` starts out NULL so that a jump from
   * any point frees exactly what has been acquired so far. */
  double* fw = NULL;
  double* bw = NULL;
  long* cw = NULL;
  long* cbw = NULL;
  long* fwd_tmp = NULL;
  long* bwd_tmp = NULL;
  double* opt = NULL;
  long* what = NULL;
  PyObject* res_opt = NULL;
  PyObject* res_what = NULL;
  PyObject* result = NULL;
  long chain_length = 0;
  long table_size = 0;

  fw = getDoubleArray(chain_param, "fweight");
  if (!fw) goto cleanup;
  bw = getDoubleArray(chain_param, "bweight");
  if (!bw) goto cleanup;
  cw = getLongArray(chain_param, "cweight");
  if (!cw) goto cleanup;
  cbw = getLongArray(chain_param, "cbweight");
  if (!cbw) goto cleanup;
  fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
  if (!fwd_tmp) goto cleanup; /* previously tested cbw by mistake */
  bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
  if (!bwd_tmp) goto cleanup; /* previously tested cbw by mistake */

  {
    PyObject* chain_length_param =
        PyObject_GetAttrString(chain_param, "length");
    if (!chain_length_param) goto cleanup;
    chain_length = PyLong_AsLong(chain_length_param);
    Py_DECREF(chain_length_param);
    if (chain_length == -1 && PyErr_Occurred()) goto cleanup;
  }

  // TODO: Can be optimized by only allocating memory for l >= i
  // TODO: float / int instead of double / long ?
  table_size = (mmax + 1) * (chain_length + 1) * (chain_length + 1);
#define OPT(m, i, l)                                  \
  opt[(m) * (chain_length + 1) * (chain_length + 1) + \
      (i) * (chain_length + 1) + (l)]
  opt = (double*)calloc(table_size, sizeof(double));
#define WHAT(m, i, l)                                  \
  what[(m) * (chain_length + 1) * (chain_length + 1) + \
       (i) * (chain_length + 1) + (l)]
  what = (long*)calloc(table_size, sizeof(long));
  /* calloc(0, ...) may legitimately return NULL, so only a failed
   * non-empty allocation is an error. */
  if (table_size > 0 && (!opt || !what)) {
    PyErr_NoMemory();
    goto cleanup;
  }

  /* Diagonal cells (i == l): finite only when m meets both thresholds. */
  for (long m = 0; m <= mmax; ++m) {
    for (long i = 0; i <= chain_length; ++i) {
      // TODO: Can be optimized to remove the IF by reordering loops
      if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
          (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
        OPT(m, i, i) = fw[i] + bw[i];
      else
        OPT(m, i, i) = INFINITY;
    }
  }

  /*
   * Off-diagonal cells (l > i).
   * NOTE(review): OPT(m - cw[j], j, l) and OPT(m - cbw[i + 1], i + 1, l)
   * read row m itself when the subtracted weight is 0, and with this loop
   * order (i ascending) those cells are not filled yet for l > j — confirm
   * that the weights are always strictly positive.
   */
  for (long m = 0; m <= mmax; ++m) {
    for (long i = 0; i <= chain_length; ++i) {
      long maxCostFWD = 0; /* running maximum, carried across l */
      for (long l = i + 1; l <= chain_length; ++l) {
        long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i];
        if (l > i + 1) {
          maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]);
          mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD);
        }
        if ((m >= mmin)) {
          long bestLeaf = -1;
          double sumFw = 0;
          double bestLeafCost = INFINITY;
          /* Try every split index j in (i, l]; sumFw accumulates
           * fw[i] + ... + fw[j - 1]. */
          for (long j = i + 1; j <= l; ++j) {
            sumFw += fw[j - 1];
            if (m >= cw[j]) {
              double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1);
              if (cost < bestLeafCost) {
                bestLeafCost = cost;
                bestLeaf = j;
              }
            }
          }
          double chainCost = INFINITY;
          if (m >= cbw[i + 1])
            chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l);
          /* Ties go to the split; WHAT == -1 marks the other choice. */
          if (bestLeafCost <= chainCost) {
            OPT(m, i, l) = bestLeafCost;
            WHAT(m, i, l) = bestLeaf;
          } else {
            OPT(m, i, l) = chainCost;
            WHAT(m, i, l) = -1;
          }
        } else {
          OPT(m, i, l) = INFINITY;
        }
      }
    }
  }

  /* Convert the result into Python world. PyList_SET_ITEM steals the
   * reference, so dropping res_opt / res_what at cleanup releases every
   * nested container created so far. */
  res_opt = PyList_New(mmax + 1);
  if (!res_opt) goto cleanup;
  res_what = PyList_New(mmax + 1);
  if (!res_what) goto cleanup;

  for (long m = 0; m <= mmax; ++m) {
    PyObject* res_opt_m = PyList_New(chain_length + 1);
    if (!res_opt_m) goto cleanup;
    PyList_SET_ITEM(res_opt, m, res_opt_m);
    PyObject* res_what_m = PyList_New(chain_length + 1);
    if (!res_what_m) goto cleanup;
    PyList_SET_ITEM(res_what, m, res_what_m);

    for (long i = 0; i <= chain_length; ++i) {
      PyObject* res_opt_m_i = PyDict_New();
      if (!res_opt_m_i) goto cleanup;
      PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
      PyObject* res_what_m_i = PyDict_New();
      if (!res_what_m_i) goto cleanup;
      PyList_SET_ITEM(res_what_m, i, res_what_m_i);

      for (long l = i; l <= chain_length; ++l) {
        long what_m_i_l = WHAT(m, i, l);
        PyObject* res_l = PyLong_FromLong(l);
        PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
        PyObject* res_what_m_i_l =
            (what_m_i_l < 0) ? Py_BuildValue("(O)", Py_True)
                             : Py_BuildValue("(Ol)", Py_False, what_m_i_l);

        int failed =
            !res_l || !res_opt_m_i_l || !res_what_m_i_l ||
            PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l) < 0 ||
            PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l) < 0;

        /* PyDict_SetItem does not steal references. */
        Py_XDECREF(res_what_m_i_l);
        Py_XDECREF(res_opt_m_i_l);
        Py_XDECREF(res_l);
        if (failed) goto cleanup;
      }
    }
  }

  result = PyTuple_Pack(2, res_opt, res_what);

cleanup:
  /* On success the tuple holds its own references; on failure result is
   * NULL and these drops free the partially built containers. */
  Py_XDECREF(res_opt);
  Py_XDECREF(res_what);
  free(opt);
  free(what);
  free(fw);
  free(bw);
  free(cw);
  free(cbw);
  free(fwd_tmp);
  free(bwd_tmp);
  return result;
}
/*
 * Flatten the index (m, i, j, k) into an offset of the packed tables used
 * by floating_compute_table, where the caller passes
 *   i = L - s, j = t - s, k = l - t      (L being the chain length)
 * and `m_factor` is the number of cells stored per value of m.
 *
 * Declared `static inline`: a plain C99 `inline` definition emits no
 * external symbol, so the call would fail to link (or to load, in a shared
 * object) whenever the compiler chooses not to inline it.
 */
static inline long floating_index_in_array(long m_factor, long m, long i,
                                           long j, long k) {
  return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
         (j * (j - 1)) / 2 + k;
}
/*
 * Index triple stored per cell of the `what` table in
 * floating_compute_table. A cell whose sp is negative is reported to
 * Python as (True,); any other cell is reported as (False, (sp, r, tp)).
 */
typedef struct {
long sp;
long r;
long tp;
} index_t;
/*
 * floating_compute_table(chain, mmax) -> (opt, what)
 *
 * Python-callable entry point. Fills two tables over (m, s, t, l) with
 * 0 <= m <= mmax and 0 <= s <= t <= l <= chain.length, then returns them
 * as a 2-tuple of lists (one entry per m) of dicts keyed by the tuple
 * (s, t, l):
 *   opt[m][(s, t, l)]  -> float
 *   what[m][(s, t, l)] -> (True,) or (False, (sp, r, tp))
 *
 * NOTE(review): the attributes read here are spelled "fweigth", "bweigth",
 * "cweigth", "cbweigth", "fwd_tmp" and "bwd_tmp", while
 * persistent_compute_table reads "fweight", "bweight", "cweight",
 * "cbweight", "fwd_mem_tmp" and "bwd_mem_tmp" — confirm which names the
 * chain object actually exposes.
 * NOTE(review): each early `return NULL` below leaves the arrays obtained
 * before it unreleased, and the calloc results are not checked.
 */
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
// Copy the chain attributes into heap-allocated C arrays (freed below).
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
if (!fwd_tmp) return NULL;
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
if (!bwd_tmp) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// Number of table cells stored for each value of m.
const long m_factor =
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
// Defined for 0 <= s <= t <= l <= chain_length, for all m
// The macros below replace the 3-argument OPT / WHAT defined earlier in
// this file with 4-argument versions backed by floating_index_in_array.
#undef OPT
#define OPT(m, s, t, l) \
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
#undef WHAT
#define WHAT(m, s, t, l) \
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
// partialSumsFW[i] = fw[0] + ... + fw[i - 1], so partialSumsFW[0] == 0.
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
partialSumsFW[i] = total;
total += fw[i];
}
partialSumsFW[chain_length] = total;
// Cells with s == t == l: finite only when m meets both thresholds.
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
OPT(m, i, i, i) = fw[i] + bw[i];
else
OPT(m, i, i, i) = INFINITY;
}
// Remaining cells, filled by increasing span d = l - s for each m.
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chain_length; ++d) { // d = l - s
for (long s = 0; s <= chain_length - d; ++s) {
long l = s + d;
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
// memNullSecond = max over s < j < l of cw[j] + cw[j + 1] + fwd_tmp[j].
long memNullSecond = 0;
for (long j = s + 1; j < l; ++j) {
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
if (val > memNullSecond) memNullSecond = val;
}
for (long t = s; t <= l; ++t) {
// chainCost is only finite for s == t with m above both thresholds.
double chainCost = INFINITY;
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
}
// Search all (r, tp, sp) with s <= r <= t < tp <= l and r < sp <= tp
// for the smallest value; candidates require cw[s] <= cw[r].
double bestLeafCost = INFINITY;
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
for (long r = s; r <= t; ++r)
if (cw[s] <= cw[r])
for (long tp = t + 1; tp <= l; ++tp)
for (long sp = r + 1; sp <= tp; ++sp) {
long mp = m - cw[r] + cw[s];
assert(mp >= 0);
if (mp >= cw[sp]) {
double value = partialSumsFW[sp] - partialSumsFW[s] +
OPT(mp - cw[sp], sp, tp, l) +
OPT(mp, r, t, tp - 1);
if (value < bestLeafCost) {
bestLeafCost = value;
bestLeaf.sp = sp;
bestLeaf.r = r;
bestLeaf.tp = tp;
}
}
}
}
// Ties go to the triple; otherwise sp = -1 marks the cell as having none.
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
OPT(m, s, t, l) = bestLeafCost;
WHAT(m, s, t, l).sp = bestLeaf.sp;
WHAT(m, s, t, l).r = bestLeaf.r;
WHAT(m, s, t, l).tp = bestLeaf.tp;
} else {
OPT(m, s, t, l) = chainCost;
WHAT(m, s, t, l).sp = -1;
}
}
}
}
// NOTE(review): partialSumsFW is not freed anywhere in this function.
free(fw);
free(bw);
free(cw);
free(cbw);
free(fwd_tmp);
free(bwd_tmp);
PyObject* res_opt = PyList_New(mmax + 1);
PyObject* res_what = PyList_New(mmax + 1);
// Convert the result into Python world
// true_tuple is one shared (True,) object stored for every cell with sp < 0.
PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyDict_New();
PyList_SET_ITEM(res_opt, m, res_opt_m);
PyObject* res_what_m = PyDict_New();
PyList_SET_ITEM(res_what, m, res_what_m);
for (long s = 0; s <= chain_length; ++s)
for (long t = s; t <= chain_length; ++t)
for (long l = t; l <= chain_length; ++l) {
PyObject* key = Py_BuildValue("(lll)", s, t, l);
PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
PyDict_SetItem(res_opt_m, key, value_opt);
PyObject* value_what = true_tuple;
index_t* idx_what = &WHAT(m, s, t, l);
if (idx_what->sp >= 0)
value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
idx_what->r, idx_what->tp);
PyDict_SetItem(res_what_m, key, value_what);
// Only per-cell tuples are released here; true_tuple is dropped once below.
if (value_what != true_tuple) Py_DECREF(value_what);
Py_DECREF(key);
Py_DECREF(value_opt);
}
}
Py_DECREF(true_tuple);
free(opt);
free(what);
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
Py_DECREF(res_opt);
Py_DECREF(res_what);
return result;
}
/*
 * griewank_heterogeneous_compute_table(chain, mmax) -> opt
 *
 * Python-callable entry point. Fills a table OPT(m, i, l) for
 * 0 <= m <= mmax and 0 <= i <= l <= chain.length and returns it as a list
 * (per m) of lists (per i) of dicts. Each dict is keyed by the offset
 * l - i (not by l itself) and holds a float.
 *
 * NOTE(review): the attributes read here are spelled "fweigth", "bweigth",
 * "cweigth" and "cbweigth", while persistent_compute_table reads
 * "fweight", "bweight", "cweight" and "cbweight" — confirm which names the
 * chain object actually exposes.
 * NOTE(review): each early `return NULL` below leaves the arrays obtained
 * before it unreleased, and the calloc results are not checked.
 */
static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
// Copy the chain attributes into heap-allocated C arrays (freed below).
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
// Redefine OPT as a dense 3-index table (any earlier definition is dropped).
#undef OPT
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double* opt = (double*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
// Compute partial sums
//   sumfw[i]    = fw[0] + ... + fw[i]
//   sumbw[i]    = bw[0] + ... + bw[i]
//   sumsumfw[i] = sumfw[0] + ... + sumfw[i]
double* sumfw = (double*)calloc(chain_length, sizeof(double));
double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
total += fw[i];
sumfw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length + 1; ++i) {
total += bw[i];
sumbw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length; ++i) {
total += sumfw[i];
sumsumfw[i] = total;
}
// Cells with l == i and l == i + 1 are set directly from the weights.
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
OPT(m, i, i) = bw[i];
else
OPT(m, i, i) = INFINITY;
if (i < chain_length) {
long maxC = fmaxl(cw[i], cw[i + 1]);
long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
else
OPT(m, i, i + 1) = INFINITY;
}
}
// Cells with l >= i + 2: minimum of noCheckpointCost and the best split.
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i + 2 <= chain_length; ++i) {
long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
// maxCW_il and maxCostFWD are running maxima carried across l.
long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
for (long l = i + 2; l <= chain_length; ++l) {
maxCW_il = fmax(maxCW_il, cw[l + 1]);
maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
long mmin = fmaxl(mminCst, maxCostFWD);
if ((m >= mmin)) {
// noCheckpointCost is assembled from the prefix sums computed above.
double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
noCheckpointCost +=
sumsumfw[l - 1] -
(i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
// valueCost = min over i < j < l of the split at j; sumFwds accumulates
// fw[i] + ... + fw[j - 1].
double valueCost = INFINITY;
if (m >= cw[i]) {
double sumFwds = 0;
for (long j = i + 1; j < l; ++j) {
sumFwds += fw[j - 1];
valueCost = fmin(
valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
}
}
OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
} else
OPT(m, i, l) = INFINITY;
}
}
free(sumfw);
free(sumbw);
free(sumsumfw);
free(fw);
free(bw);
free(cw);
free(cbw);
PyObject* res_opt = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_opt, m, res_opt_m);
for (long i = 0; i <= chain_length; ++i) {
PyObject* res_opt_m_i = PyDict_New();
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
for (long l = i; l <= chain_length; ++l) {
// Dict key is the offset l - i.
PyObject* res_l = PyLong_FromLong(l - i);
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
Py_DECREF(res_opt_m_i_l);
Py_DECREF(res_l);
}
}
}
free(opt);
return res_opt;
}
/*
 * Functions exported by the extension module. Every entry takes positional
 * arguments only (METH_VARARGS); the all-zero entry terminates the table.
 */
static PyMethodDef dynamic_programs_methods[] = {
    {.ml_name = "persistent_compute_table",
     .ml_meth = persistent_compute_table,
     .ml_flags = METH_VARARGS,
     .ml_doc = "Compute the optimal table with the persistent algorithm."},
    {.ml_name = "floating_compute_table",
     .ml_meth = floating_compute_table,
     .ml_flags = METH_VARARGS,
     .ml_doc = "Compute the optimal table with the floating algorithm."},
    {.ml_name = "griewank_heterogeneous_compute_table",
     .ml_meth = griewank_heterogeneous_compute_table,
     .ml_flags = METH_VARARGS,
     .ml_doc =
         "Compute the optimal table for the Griewank Heterogeneous Model."},
    {.ml_name = NULL, .ml_meth = NULL, .ml_flags = 0, .ml_doc = NULL},
};
/*
 * Module definition for dynamic_programs_C_version: no docstring, and a
 * per-interpreter state size of -1 (the module keeps state in globals).
 */
static struct PyModuleDef dynamic_programs_module = {
    .m_base = PyModuleDef_HEAD_INIT,
    .m_name = "dynamic_programs_C_version",
    .m_doc = NULL,
    .m_size = -1,
    .m_methods = dynamic_programs_methods,
};
/* Import hook: builds the module object from dynamic_programs_module. */
PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
  PyObject* module = PyModule_Create(&dynamic_programs_module);
  return module;
}

View File

@ -1,7 +1,7 @@
import os import os
import subprocess import subprocess
import re import re
from setuptools import find_packages, setup from setuptools import find_packages, setup, Extension
# ninja build does not work unless include_dirs are abs path # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))

View File

@ -0,0 +1,65 @@
import copy
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
import torch.fx
import colossalai
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
import pytest
from colossalai import META_COMPATIBILITY
# MetaTensor is only imported when colossalai reports meta compatibility.
if META_COMPATIBILITY:
    from colossalai.fx.profiler.tensor import MetaTensor

# `withcodegen` records whether ActivationCheckpointCodeGen could be imported;
# the test below is skipped when it could not. Only ImportError triggers the
# fallback, so unrelated failures (including KeyboardInterrupt) are no longer
# silently swallowed by a bare `except`.
try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    withcodegen = True
except ImportError:
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    withcodegen = False
def _run_C_solver_consistency_test(rank=0):
    """Check that the C and Python rotor solvers return the same sequence.

    For resnet18 and resnet50 the model is traced, meta information is
    propagated, and ``solver_rotor`` is run twice: once with
    ``force_python=True`` and once with the default (C) path. The two
    resulting operation lists must have equal length and equal ``repr``
    element by element.

    Args:
        rank (int): rank passed to ``colossalai.launch``. Defaults to 0.
    """
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    try:
        for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
            model = M()
            data = torch.rand(128, 3, 224, 224, device='meta')

            tracer = ColoTracer()
            graph = tracer.trace(model, meta_args={"x": data})
            graph.set_codegen(ActivationCheckpointCodeGen())
            gm = ColoGraphModule(model, graph, model.__class__.__name__)

            # Fall back to the plain meta-device tensor when MetaTensor is not
            # available, so `data_meta` is always bound before it is used.
            data_meta = data
            if META_COMPATIBILITY:
                data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
            MetaInfoProp(gm).run(data_meta)

            # memory budget is given in MB; the solver expects bytes
            mem_limit = mem_budget * 1024 * 1024

            # python solver
            gm = solver_rotor(gm, data_meta, mem_limit, force_python=True)
            sequence_python: Sequence = copy.deepcopy(gm.__sequence__)

            # C solver
            gm = solver_rotor(gm, data_meta, mem_limit)
            sequence_C: Sequence = copy.deepcopy(gm.__sequence__)

            sequence_python = sequence_python.list_operations()
            sequence_C = sequence_C.list_operations()

            # make sure the solutions are the same
            assert len(sequence_python) == len(sequence_C) and \
                all(repr(python_op) == repr(C_op) for (python_op, C_op) in zip(sequence_python, sequence_C))
    finally:
        # tear the global context down even when an assertion fails
        gpc.destroy()
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
def test_C_solver_consistency():
    """Run the C/Python solver consistency check in one spawned process."""
    worker = _run_C_solver_consistency_test
    mp.spawn(worker, nprocs=1)
# Direct execution runs the check in-process as rank 0 (no spawning).
if "__main__" == __name__:
    _run_C_solver_consistency_test(rank=0)