mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add rotor C version (#1658)
* [autoparallel] add rotor c version * [fx] remove metainfoprop in rotor solver * [autoparallel] modify C code format * [autoparallel] remove build.py * [autoparallel] fix C extension build * [autoparallel] add C solver consistency test * [autoparallel] remove some unused imports * [autoparallel] refactor rotor solver code * [autoparallel] replace print with colossalai logger * [autoparallel] ranks fixedpull/1678/head
parent
11ec070e53
commit
1df98d5b66
|
@ -0,0 +1,15 @@
|
|||
from setuptools import setup, Extension
|
||||
import os
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_modules = [Extension(
|
||||
'dynamic_programs_C_version',
|
||||
sources=[os.path.join(this_dir, 'dynamic_programs.c')],
|
||||
)]
|
||||
|
||||
setup(
|
||||
name='rotor c extension',
|
||||
version='0.1',
|
||||
description='rotor c extension for faster dp computing',
|
||||
ext_modules=ext_modules,
|
||||
)
|
|
@ -5,9 +5,8 @@ from colossalai.fx.profiler import activation_size, parameter_size
|
|||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai import META_COMPATIBILITY
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
|
@ -323,34 +322,77 @@ def solver_rotor(gm: ColoGraphModule,
|
|||
mem_limit: int,
|
||||
mem_slots: int = 500,
|
||||
cnode: List[str] = None,
|
||||
eps: float = 0.0) -> ColoGraphModule:
|
||||
eps: float = 0.0,
|
||||
force_python: bool = False) -> ColoGraphModule:
|
||||
"""solver that automatically find activation checkpoint in rotor's manner
|
||||
|
||||
Args:
|
||||
gm (ColoGraphModule): ColoGraphModule generated by tracing model.
|
||||
gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
|
||||
data (torch.Tensor): input data.
|
||||
mem_limit (int): memory budget in Byte.
|
||||
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
|
||||
cnode (List[Node], optional): common node list for linearize. Defaults to None.
|
||||
eps (float): epsilon for memory decay. Defaults to 0.0
|
||||
force_python (bool): force to use python version of dynamic programs
|
||||
|
||||
Returns:
|
||||
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
|
||||
"""
|
||||
|
||||
node_list = linearize(gm, cnode)
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
# try to import C version solver if force_python is not set
|
||||
logger = get_dist_logger()
|
||||
if not force_python:
|
||||
try:
|
||||
from .dynamic_programs_C_version import persistent_compute_table
|
||||
CVERSION = True
|
||||
|
||||
# build module if module not found
|
||||
except ModuleNotFoundError:
|
||||
import subprocess
|
||||
import os
|
||||
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
result = subprocess.Popen(
|
||||
f'python {os.path.join(this_dir, "build_c_ext.py")} build_ext --build-lib={this_dir}',
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=True)
|
||||
if result.wait() == 0:
|
||||
logger.info("dynamic_programs_C_version has been built!", ranks=[0])
|
||||
from .dynamic_programs_C_version import persistent_compute_table
|
||||
CVERSION = True
|
||||
else:
|
||||
logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
|
||||
CVERSION = False
|
||||
else:
|
||||
CVERSION = False
|
||||
|
||||
# check if metainfoprop is done
|
||||
if any(len(node.meta) == 0 for node in gm.graph.nodes):
|
||||
raise RuntimeError(
|
||||
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")
|
||||
|
||||
# linearize the graph
|
||||
node_list = linearize(gm, cnode)
|
||||
|
||||
# construct chain
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
chain._discretize(mem_unit)
|
||||
opt_table = _compute_table(chain, mem_slots)
|
||||
|
||||
# use C version if possible
|
||||
if CVERSION and not force_python:
|
||||
logger.info("Using C version rotor solver!", ranks=[0])
|
||||
opt_table = persistent_compute_table(chain, mem_slots)
|
||||
else:
|
||||
opt_table = _compute_table(chain, mem_slots)
|
||||
logger.info("Using python version rotor solver!", ranks=[0])
|
||||
|
||||
# found sequence
|
||||
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
||||
_annotate_from_sequence(sequence, node_list)
|
||||
|
||||
# set __sequence__ attribute to GraphModule
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
|
|
@ -0,0 +1,513 @@
|
|||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
long* PySequenceToLongArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
long* result = (long*)calloc(len + 1, sizeof(long));
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_GetItem(pylist, i);
|
||||
result[i] = PyLong_AsLong(item);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
result[len] = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
double* PySequenceToDoubleArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
double* result = (double*)calloc(len + 1, sizeof(double));
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
PyObject* item = PySequence_GetItem(pylist, i);
|
||||
result[i] = PyFloat_AsDouble(item);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
result[len] = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
long* getLongArray(PyObject* container, const char* attributeName) {
|
||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
||||
long* result = PySequenceToLongArray(sequence);
|
||||
Py_DECREF(sequence);
|
||||
return result;
|
||||
}
|
||||
|
||||
double* getDoubleArray(PyObject* container, const char* attributeName) {
|
||||
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
|
||||
double* result = PySequenceToDoubleArray(sequence);
|
||||
Py_DECREF(sequence);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
||||
PyObject* chain_param;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
||||
|
||||
double* fw = getDoubleArray(chain_param, "fweight");
|
||||
if (!fw) return NULL;
|
||||
|
||||
double* bw = getDoubleArray(chain_param, "bweight");
|
||||
if (!bw) return NULL;
|
||||
|
||||
long* cw = getLongArray(chain_param, "cweight");
|
||||
if (!cw) return NULL;
|
||||
|
||||
long* cbw = getLongArray(chain_param, "cbweight");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
||||
if (!chain_length_param) return NULL;
|
||||
long chain_length = PyLong_AsLong(chain_length_param);
|
||||
Py_DECREF(chain_length_param);
|
||||
|
||||
// TODO: Can be optimized by only allocating memory for l >= i
|
||||
// TODO: float / int instead of double / long ?
|
||||
#define OPT(m, i, l) \
|
||||
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
|
||||
(i) * (chain_length + 1) + (l)]
|
||||
double* opt = (double*)calloc(
|
||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
|
||||
|
||||
#define WHAT(m, i, l) \
|
||||
what[(m) * (chain_length + 1) * (chain_length + 1) + \
|
||||
(i) * (chain_length + 1) + (l)]
|
||||
long* what = (long*)calloc(
|
||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i)
|
||||
// TODO: Can be optimized to remove the IF by reordering loops
|
||||
if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
|
||||
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
|
||||
OPT(m, i, i) = fw[i] + bw[i];
|
||||
else
|
||||
OPT(m, i, i) = INFINITY;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
long maxCostFWD = 0;
|
||||
for (long l = i + 1; l <= chain_length; ++l) {
|
||||
long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i];
|
||||
if (l > i + 1) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]);
|
||||
mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD);
|
||||
}
|
||||
if ((m >= mmin)) {
|
||||
long bestLeaf = -1;
|
||||
double sumFw = 0;
|
||||
double bestLeafCost = INFINITY;
|
||||
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
|
||||
/// i+1
|
||||
for (long j = i + 1; j <= l; ++j) {
|
||||
sumFw += fw[j - 1];
|
||||
if (m >= cw[j]) {
|
||||
double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1);
|
||||
if (cost < bestLeafCost) {
|
||||
bestLeafCost = cost;
|
||||
bestLeaf = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
double chainCost = INFINITY;
|
||||
if (m >= cbw[i + 1])
|
||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l);
|
||||
if (bestLeafCost <= chainCost) {
|
||||
OPT(m, i, l) = bestLeafCost;
|
||||
WHAT(m, i, l) = bestLeaf;
|
||||
} else {
|
||||
OPT(m, i, l) = chainCost;
|
||||
WHAT(m, i, l) = -1;
|
||||
}
|
||||
} else
|
||||
OPT(m, i, l) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
free(fw);
|
||||
free(bw);
|
||||
free(cw);
|
||||
free(cbw);
|
||||
free(fwd_tmp);
|
||||
free(bwd_tmp);
|
||||
|
||||
PyObject* res_opt = PyList_New(mmax + 1);
|
||||
PyObject* res_what = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* res_opt_m = PyList_New(chain_length + 1);
|
||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
||||
PyObject* res_what_m = PyList_New(chain_length + 1);
|
||||
PyList_SET_ITEM(res_what, m, res_what_m);
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
PyObject* res_opt_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
|
||||
PyObject* res_what_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(res_what_m, i, res_what_m_i);
|
||||
for (long l = i; l <= chain_length; ++l) {
|
||||
PyObject* res_l = PyLong_FromLong(l);
|
||||
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
|
||||
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
|
||||
Py_DECREF(res_opt_m_i_l);
|
||||
PyObject* res_what_m_i_l;
|
||||
long what_m_i_l = WHAT(m, i, l);
|
||||
if (what_m_i_l < 0)
|
||||
res_what_m_i_l = Py_BuildValue("(O)", Py_True);
|
||||
else
|
||||
res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
|
||||
PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
|
||||
Py_DECREF(res_what_m_i_l);
|
||||
Py_DECREF(res_l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(opt);
|
||||
free(what);
|
||||
|
||||
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
|
||||
Py_DECREF(res_opt);
|
||||
Py_DECREF(res_what);
|
||||
return result;
|
||||
}
|
||||
|
||||
// long i = L - s, j = t - s, k = l - t
|
||||
inline long floating_index_in_array(long m_factor, long m, long i, long j,
|
||||
long k) {
|
||||
return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
|
||||
(j * (j - 1)) / 2 + k;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
long sp;
|
||||
long r;
|
||||
long tp;
|
||||
} index_t;
|
||||
|
||||
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
|
||||
PyObject* chain_param;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
||||
|
||||
double* fw = getDoubleArray(chain_param, "fweigth");
|
||||
if (!fw) return NULL;
|
||||
|
||||
double* bw = getDoubleArray(chain_param, "bweigth");
|
||||
if (!bw) return NULL;
|
||||
|
||||
long* cw = getLongArray(chain_param, "cweigth");
|
||||
if (!cw) return NULL;
|
||||
|
||||
long* cbw = getLongArray(chain_param, "cbweigth");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
|
||||
if (!fwd_tmp) return NULL;
|
||||
|
||||
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
|
||||
if (!bwd_tmp) return NULL;
|
||||
|
||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
||||
if (!chain_length_param) return NULL;
|
||||
long chain_length = PyLong_AsLong(chain_length_param);
|
||||
Py_DECREF(chain_length_param);
|
||||
|
||||
const long m_factor =
|
||||
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
|
||||
|
||||
// Defined for 0 <= s <= t <= l <= chain_length, for all m
|
||||
#undef OPT
|
||||
#define OPT(m, s, t, l) \
|
||||
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
|
||||
(l) - (t))]
|
||||
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
|
||||
|
||||
#undef WHAT
|
||||
#define WHAT(m, s, t, l) \
|
||||
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
|
||||
(l) - (t))]
|
||||
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
|
||||
|
||||
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
|
||||
double total = 0;
|
||||
for (long i = 0; i < chain_length; ++i) {
|
||||
partialSumsFW[i] = total;
|
||||
total += fw[i];
|
||||
}
|
||||
partialSumsFW[chain_length] = total;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
// TODO: Can be optimized to remove the IF by reordering loops
|
||||
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
|
||||
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
|
||||
OPT(m, i, i, i) = fw[i] + bw[i];
|
||||
else
|
||||
OPT(m, i, i, i) = INFINITY;
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long d = 1; d <= chain_length; ++d) { // d = l - s
|
||||
for (long s = 0; s <= chain_length - d; ++s) {
|
||||
long l = s + d;
|
||||
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
|
||||
long memNullSecond = 0;
|
||||
for (long j = s + 1; j < l; ++j) {
|
||||
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
|
||||
if (val > memNullSecond) memNullSecond = val;
|
||||
}
|
||||
for (long t = s; t <= l; ++t) {
|
||||
double chainCost = INFINITY;
|
||||
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
|
||||
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
|
||||
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
|
||||
}
|
||||
double bestLeafCost = INFINITY;
|
||||
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
|
||||
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
|
||||
for (long r = s; r <= t; ++r)
|
||||
if (cw[s] <= cw[r])
|
||||
for (long tp = t + 1; tp <= l; ++tp)
|
||||
for (long sp = r + 1; sp <= tp; ++sp) {
|
||||
long mp = m - cw[r] + cw[s];
|
||||
assert(mp >= 0);
|
||||
if (mp >= cw[sp]) {
|
||||
double value = partialSumsFW[sp] - partialSumsFW[s] +
|
||||
OPT(mp - cw[sp], sp, tp, l) +
|
||||
OPT(mp, r, t, tp - 1);
|
||||
if (value < bestLeafCost) {
|
||||
bestLeafCost = value;
|
||||
bestLeaf.sp = sp;
|
||||
bestLeaf.r = r;
|
||||
bestLeaf.tp = tp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
|
||||
OPT(m, s, t, l) = bestLeafCost;
|
||||
WHAT(m, s, t, l).sp = bestLeaf.sp;
|
||||
WHAT(m, s, t, l).r = bestLeaf.r;
|
||||
WHAT(m, s, t, l).tp = bestLeaf.tp;
|
||||
} else {
|
||||
OPT(m, s, t, l) = chainCost;
|
||||
WHAT(m, s, t, l).sp = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(fw);
|
||||
free(bw);
|
||||
free(cw);
|
||||
free(cbw);
|
||||
free(fwd_tmp);
|
||||
free(bwd_tmp);
|
||||
|
||||
PyObject* res_opt = PyList_New(mmax + 1);
|
||||
PyObject* res_what = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* res_opt_m = PyDict_New();
|
||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
||||
PyObject* res_what_m = PyDict_New();
|
||||
PyList_SET_ITEM(res_what, m, res_what_m);
|
||||
for (long s = 0; s <= chain_length; ++s)
|
||||
for (long t = s; t <= chain_length; ++t)
|
||||
for (long l = t; l <= chain_length; ++l) {
|
||||
PyObject* key = Py_BuildValue("(lll)", s, t, l);
|
||||
PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
|
||||
PyDict_SetItem(res_opt_m, key, value_opt);
|
||||
PyObject* value_what = true_tuple;
|
||||
index_t* idx_what = &WHAT(m, s, t, l);
|
||||
if (idx_what->sp >= 0)
|
||||
value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
|
||||
idx_what->r, idx_what->tp);
|
||||
PyDict_SetItem(res_what_m, key, value_what);
|
||||
if (value_what != true_tuple) Py_DECREF(value_what);
|
||||
Py_DECREF(key);
|
||||
Py_DECREF(value_opt);
|
||||
}
|
||||
}
|
||||
|
||||
Py_DECREF(true_tuple);
|
||||
|
||||
free(opt);
|
||||
free(what);
|
||||
|
||||
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
|
||||
Py_DECREF(res_opt);
|
||||
Py_DECREF(res_what);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
|
||||
PyObject* args) {
|
||||
PyObject* chain_param;
|
||||
int mmax;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
|
||||
|
||||
double* fw = getDoubleArray(chain_param, "fweigth");
|
||||
if (!fw) return NULL;
|
||||
|
||||
double* bw = getDoubleArray(chain_param, "bweigth");
|
||||
if (!bw) return NULL;
|
||||
|
||||
long* cw = getLongArray(chain_param, "cweigth");
|
||||
if (!cw) return NULL;
|
||||
|
||||
long* cbw = getLongArray(chain_param, "cbweigth");
|
||||
if (!cbw) return NULL;
|
||||
|
||||
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
|
||||
if (!chain_length_param) return NULL;
|
||||
long chain_length = PyLong_AsLong(chain_length_param);
|
||||
Py_DECREF(chain_length_param);
|
||||
|
||||
// TODO: Can be optimized by only allocating memory for l >= i
|
||||
// TODO: float / int instead of double / long ?
|
||||
#undef OPT
|
||||
#define OPT(m, i, l) \
|
||||
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
|
||||
(i) * (chain_length + 1) + (l)]
|
||||
double* opt = (double*)calloc(
|
||||
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
|
||||
|
||||
// Compute partial sums
|
||||
double* sumfw = (double*)calloc(chain_length, sizeof(double));
|
||||
double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
|
||||
double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
|
||||
|
||||
double total = 0;
|
||||
for (long i = 0; i < chain_length; ++i) {
|
||||
total += fw[i];
|
||||
sumfw[i] = total;
|
||||
}
|
||||
|
||||
total = 0;
|
||||
for (long i = 0; i < chain_length + 1; ++i) {
|
||||
total += bw[i];
|
||||
sumbw[i] = total;
|
||||
}
|
||||
|
||||
total = 0;
|
||||
for (long i = 0; i < chain_length; ++i) {
|
||||
total += sumfw[i];
|
||||
sumsumfw[i] = total;
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
// TODO: Can be optimized to remove the IF by reordering loops
|
||||
if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
|
||||
OPT(m, i, i) = bw[i];
|
||||
else
|
||||
OPT(m, i, i) = INFINITY;
|
||||
|
||||
if (i < chain_length) {
|
||||
long maxC = fmaxl(cw[i], cw[i + 1]);
|
||||
long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
|
||||
if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
|
||||
OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
|
||||
else
|
||||
OPT(m, i, i + 1) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i + 2 <= chain_length; ++i) {
|
||||
long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
|
||||
long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
|
||||
long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
|
||||
for (long l = i + 2; l <= chain_length; ++l) {
|
||||
maxCW_il = fmax(maxCW_il, cw[l + 1]);
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
|
||||
long mmin = fmaxl(mminCst, maxCostFWD);
|
||||
if ((m >= mmin)) {
|
||||
double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
|
||||
noCheckpointCost +=
|
||||
sumsumfw[l - 1] -
|
||||
(i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
|
||||
|
||||
double valueCost = INFINITY;
|
||||
if (m >= cw[i]) {
|
||||
double sumFwds = 0;
|
||||
for (long j = i + 1; j < l; ++j) {
|
||||
sumFwds += fw[j - 1];
|
||||
valueCost = fmin(
|
||||
valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
|
||||
}
|
||||
}
|
||||
OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
|
||||
} else
|
||||
OPT(m, i, l) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
free(sumfw);
|
||||
free(sumbw);
|
||||
free(sumsumfw);
|
||||
free(fw);
|
||||
free(bw);
|
||||
free(cw);
|
||||
free(cbw);
|
||||
|
||||
PyObject* res_opt = PyList_New(mmax + 1);
|
||||
|
||||
// Convert the result into Python world
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
PyObject* res_opt_m = PyList_New(chain_length + 1);
|
||||
PyList_SET_ITEM(res_opt, m, res_opt_m);
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
PyObject* res_opt_m_i = PyDict_New();
|
||||
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
|
||||
for (long l = i; l <= chain_length; ++l) {
|
||||
PyObject* res_l = PyLong_FromLong(l - i);
|
||||
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
|
||||
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
|
||||
Py_DECREF(res_opt_m_i_l);
|
||||
Py_DECREF(res_l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(opt);
|
||||
|
||||
return res_opt;
|
||||
}
|
||||
|
||||
static PyMethodDef dynamic_programs_methods[] = {
|
||||
{"persistent_compute_table", persistent_compute_table, METH_VARARGS,
|
||||
"Compute the optimal table with the persistent algorithm."},
|
||||
{"floating_compute_table", floating_compute_table, METH_VARARGS,
|
||||
"Compute the optimal table with the floating algorithm."},
|
||||
{"griewank_heterogeneous_compute_table",
|
||||
griewank_heterogeneous_compute_table, METH_VARARGS,
|
||||
"Compute the optimal table for the Griewank Heterogeneous Model."},
|
||||
{NULL, NULL, 0, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
static struct PyModuleDef dynamic_programs_module = {
|
||||
PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
|
||||
NULL, /* module documentation, may be NULL */
|
||||
-1, /* size of per-interpreter state of the module,
|
||||
or -1 if the module keeps state in global variables. */
|
||||
dynamic_programs_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
|
||||
return PyModule_Create(&dynamic_programs_module);
|
||||
}
|
6
setup.py
6
setup.py
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import subprocess
|
||||
import re
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools import find_packages, setup, Extension
|
||||
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
@ -100,7 +100,7 @@ def get_version():
|
|||
version += f'+torch{torch_version}cu{cuda_version}'
|
||||
return version
|
||||
|
||||
|
||||
|
||||
if build_cuda_ext:
|
||||
try:
|
||||
import torch
|
||||
|
@ -115,7 +115,7 @@ if build_cuda_ext:
|
|||
except ImportError:
|
||||
print('torch is not found. CUDA extension will not be installed')
|
||||
build_cuda_ext = False
|
||||
|
||||
|
||||
if build_cuda_ext:
|
||||
build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
import copy
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
import torch.fx
|
||||
import colossalai
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.passes.algorithms import solver_rotor
|
||||
from colossalai.fx.passes.algorithms.operation import Sequence
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port
|
||||
import pytest
|
||||
from colossalai import META_COMPATIBILITY
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
||||
try:
|
||||
from colossalai.fx.codegen import ActivationCheckpointCodeGen
|
||||
withcodegen = True
|
||||
except:
|
||||
from colossalai.fx.codegen import python_code_with_activation_checkpoint
|
||||
withcodegen = False
|
||||
|
||||
|
||||
def _run_C_solver_consistency_test(rank=0):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
|
||||
for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
|
||||
model = M()
|
||||
data = torch.rand(128, 3, 224, 224, device='meta')
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"x": data})
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
if META_COMPATIBILITY:
|
||||
data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data_meta)
|
||||
|
||||
# python solver
|
||||
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
|
||||
sequence_python: Sequence = copy.deepcopy(gm.__sequence__)
|
||||
|
||||
# C solver
|
||||
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
|
||||
sequence_C: Sequence = copy.deepcopy(gm.__sequence__)
|
||||
|
||||
sequence_python = sequence_python.list_operations()
|
||||
sequence_C = sequence_C.list_operations()
|
||||
|
||||
# make sure the solutions are the same
|
||||
assert len(sequence_python) == len(sequence_C) and \
|
||||
all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C))
|
||||
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
|
||||
def test_C_solver_consistency():
|
||||
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_run_C_solver_consistency_test(rank=0)
|
Loading…
Reference in New Issue