mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
6.9 KiB
210 lines
6.9 KiB
#define PY_SSIZE_T_CLEAN
|
|
#include <Python.h>
|
|
|
|
/*
 * Rotor solver for the checkpointing problem, implemented in C. We follow the
 * model described in the paper `Optimal checkpointing for heterogeneous chains:
 * how to train deep neural networks with limited memory`
 * (https://hal.inria.fr/hal-02352969). Some lines of the code are adapted from
 * https://gitlab.inria.fr/hiepacs/rotor.
 */
|
|
/*
 * Convert a Python sequence of integers into a freshly allocated C array.
 *
 * The returned buffer holds PySequence_Size(pylist) converted items followed
 * by one extra element that is zero.  The caller owns the buffer and must
 * release it with free().
 *
 * Returns NULL on failure, always with a Python exception set:
 *   - TypeError if `pylist` is NULL or not a sequence (an exception that is
 *     already pending, e.g. from a failed attribute lookup, is kept as is);
 *   - MemoryError if the buffer cannot be allocated;
 *   - whatever PySequence_Size / PySequence_GetItem / PyLong_AsLong raised.
 */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!(pylist && PySequence_Check(pylist))) {
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_TypeError, "expected a sequence of integers");
    }
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception set by PySequence_Size */

  /* calloc zero-fills, so the trailing element result[len] is already 0. */
  long* result = (long*)calloc((size_t)len + 1, sizeof(long));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    long value = PyLong_AsLong(item);
    Py_DECREF(item);
    /* -1 is also a legal value; only PyErr_Occurred() tells errors apart. */
    if (value == -1 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
    result[i] = value;
  }
  return result;
}
|
|
|
|
/*
 * Convert a Python sequence of numbers into a freshly allocated C array of
 * doubles.
 *
 * The returned buffer holds PySequence_Size(pylist) converted items followed
 * by one extra element that is zero.  The caller owns the buffer and must
 * release it with free().
 *
 * Returns NULL on failure, always with a Python exception set:
 *   - TypeError if `pylist` is NULL or not a sequence (an exception that is
 *     already pending, e.g. from a failed attribute lookup, is kept as is);
 *   - MemoryError if the buffer cannot be allocated;
 *   - whatever PySequence_Size / PySequence_GetItem / PyFloat_AsDouble raised.
 */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!(pylist && PySequence_Check(pylist))) {
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_TypeError, "expected a sequence of numbers");
    }
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception set by PySequence_Size */

  /* calloc zero-fills, so the trailing element result[len] is already 0. */
  double* result = (double*)calloc((size_t)len + 1, sizeof(double));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    double value = PyFloat_AsDouble(item);
    Py_DECREF(item);
    /* -1.0 is also a legal value; only PyErr_Occurred() tells errors apart. */
    if (value == -1.0 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
    result[i] = value;
  }
  return result;
}
|
|
|
|
/*
 * Read attribute `attributeName` from `container` and convert it with
 * PySequenceToLongArray().
 *
 * Returns the array produced by PySequenceToLongArray() (owned by the caller,
 * to be released with free()), or NULL if the attribute lookup fails or the
 * conversion returns NULL.  When the lookup fails the exception raised by
 * PyObject_GetAttrString() is left pending.
 */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A missing attribute yields NULL; it must not reach Py_DECREF. */
  if (!sequence) return NULL;

  long* result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
|
|
|
|
/*
 * Read attribute `attributeName` from `container` and convert it with
 * PySequenceToDoubleArray().
 *
 * Returns the array produced by PySequenceToDoubleArray() (owned by the
 * caller, to be released with free()), or NULL if the attribute lookup fails
 * or the conversion returns NULL.  When the lookup fails the exception raised
 * by PyObject_GetAttrString() is left pending.
 */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A missing attribute yields NULL; it must not reach Py_DECREF. */
  if (!sequence) return NULL;

  double* result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
|
|
|
|
/* Larger of two long values (integer-only replacement for fmaxl). */
static long rotorMaxLong(long a, long b) { return a > b ? a : b; }

/*
 * Store one (cost, back-pointer) pair under key `l` in the two result dicts.
 *
 * A negative `leaf` is stored as the tuple (True,), any other value as
 * (False, leaf).  Returns 0 on success, -1 with a Python exception set on
 * failure.
 */
static int rotorStoreEntry(PyObject* costDict, PyObject* ptrDict, long l,
                           double cost, long leaf) {
  int status = -1;
  PyObject* key = NULL;
  PyObject* costObj = NULL;
  PyObject* ptrObj = NULL;

  key = PyLong_FromLong(l);
  if (!key) goto done;

  costObj = PyFloat_FromDouble(cost);
  if (!costObj) goto done;
  if (PyDict_SetItem(costDict, key, costObj) < 0) goto done;

  if (leaf < 0) {
    ptrObj = Py_BuildValue("(O)", Py_True);
  } else {
    ptrObj = Py_BuildValue("(Ol)", Py_False, leaf);
  }
  if (!ptrObj) goto done;
  if (PyDict_SetItem(ptrDict, key, ptrObj) < 0) goto done;

  status = 0;

done:
  /* PyDict_SetItem does not steal references, so drop ours in every case. */
  Py_XDECREF(ptrObj);
  Py_XDECREF(costObj);
  Py_XDECREF(key);
  return status;
}

/*
 * compute_table(chain, mmax) -> (cost_table, back_ptr)
 *
 * Fills the rotor dynamic-programming tables for `chain` and every memory
 * budget m in 0..mmax.
 *
 * Arguments (parsed with "Oi"):
 *   chain  object providing len() and the sequence attributes `ftime`,
 *          `btime` (converted to double) and `x`, `xbar`, `ftmp`, `btmp`
 *          (converted to long).
 *   mmax   largest memory budget; must be >= 0.
 *
 * Returns a 2-tuple of lists, both indexed as table[m][i][l] with
 * 0 <= m <= mmax, 0 <= i <= len(chain) and i <= l <= len(chain); the innermost
 * container is a dict keyed by l.
 *   cost_table[m][i][l]  float cost, `inf` when the DP marks it infeasible.
 *   back_ptr[m][i][l]    (True,) when the stored back pointer is negative,
 *                        otherwise (False, j).  Entries the DP never writes
 *                        keep the zero from calloc and appear as (False, 0).
 *
 * Returns NULL with an exception set on any failure; all temporary buffers
 * are released on every path.
 *
 * NOTE(review): the loops index the input arrays up to len(chain) + 1 and the
 * tables by m - x[j] / m - xbar[i + 1].  This assumes the attribute sequences
 * are long enough and hold non-negative values -- confirm against the Python
 * class that builds `chain`.
 */
static PyObject* computeTable(PyObject* self, PyObject* args) {
  PyObject* chainParam;
  int mmax;

  /* Everything released at `done` starts out NULL so cleanup is uniform. */
  double* ftime = NULL;
  double* btime = NULL;
  long* x = NULL;
  long* xbar = NULL;
  long* ftmp = NULL;
  long* btmp = NULL;
  double* costTable = NULL;
  long* backPtr = NULL;
  PyObject* pyCostTable = NULL;
  PyObject* pyBackPtr = NULL;
  PyObject* result = NULL;
  long chainLength = 0;
  size_t span = 0;
  size_t budgets = 0;
  size_t tableSize = 0;

  (void)self;

  if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL;
  if (mmax < 0) {
    PyErr_SetString(PyExc_ValueError, "mmax must be non-negative");
    return NULL;
  }

  ftime = getDoubleArray(chainParam, "ftime");
  if (!ftime) goto done;

  btime = getDoubleArray(chainParam, "btime");
  if (!btime) goto done;

  x = getLongArray(chainParam, "x");
  if (!x) goto done;

  xbar = getLongArray(chainParam, "xbar");
  if (!xbar) goto done;

  /* Was read from "btmp" by mistake, which made ftmp a copy of btmp. */
  ftmp = getLongArray(chainParam, "ftmp");
  if (!ftmp) goto done;

  btmp = getLongArray(chainParam, "btmp");
  if (!btmp) goto done;

  {
    Py_ssize_t length = PyObject_Length(chainParam);
    if (length < 0) goto done; /* exception set by PyObject_Length */
    if (length == 0) {
      PyErr_SetString(PyExc_ValueError, "chain must not be empty");
      goto done;
    }
    if (length >= LONG_MAX) {
      PyErr_SetString(PyExc_OverflowError, "chain is too long");
      goto done;
    }
    chainLength = (long)length;
  }

  /* Reject table sizes whose element count would overflow. */
  span = (size_t)chainLength + 1;
  budgets = (size_t)mmax + 1;
  if (span > (size_t)PY_SSIZE_T_MAX / span ||
      budgets > (size_t)PY_SSIZE_T_MAX / (span * span)) {
    PyErr_NoMemory();
    goto done;
  }
  tableSize = budgets * span * span;

#define COST_TABLE(m, i, l)                                \
  costTable[(m) * (chainLength + 1) * (chainLength + 1) + \
            (i) * (chainLength + 1) + (l)]
  costTable = (double*)calloc(tableSize, sizeof(double));

#define BACK_PTR(m, i, l)                                \
  backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \
          (i) * (chainLength + 1) + (l)]
  backPtr = (long*)calloc(tableSize, sizeof(long));

  if (!costTable || !backPtr) {
    PyErr_NoMemory();
    goto done;
  }

  /* Diagonal entries (i == l): a single stage run forward then backward. */
  for (long m = 0; m <= mmax; ++m) {
    for (long i = 0; i <= chainLength; ++i) {
      if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
          (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) {
        COST_TABLE(m, i, i) = ftime[i] + btime[i];
      } else {
        COST_TABLE(m, i, i) = INFINITY;
      }
    }
  }

  /* Off-diagonal entries, by increasing sub-chain length d = idx - i. */
  for (long m = 0; m <= mmax; ++m) {
    for (long d = 1; d <= chainLength; ++d) {
      for (long i = 0; i <= chainLength - d; ++i) {
        long idx = i + d;

        /* Smallest budget for which the sub-chain i..idx is considered. */
        long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            maxCostFWD = rotorMaxLong(maxCostFWD, x[j] + x[j + 1] + ftmp[j]);
          }
          mmin = rotorMaxLong(mmin, x[idx + 1] + maxCostFWD);
        }

        if (m < mmin) {
          COST_TABLE(m, i, idx) = INFINITY;
          continue;
        }

        /* Option 1: split at j, paying x[j] out of the budget. */
        long bestLeaf = -1;
        double sumFw = 0;
        double bestLeafCost = INFINITY;
        for (long j = i + 1; j <= idx; ++j) {
          sumFw += ftime[j - 1];
          if (m >= x[j]) {
            double cost = sumFw + COST_TABLE(m - x[j], j, idx) +
                          COST_TABLE(m, i, j - 1);
            if (cost < bestLeafCost) {
              bestLeafCost = cost;
              bestLeaf = j;
            }
          }
        }

        /* Option 2: pay xbar[i + 1] and continue from i + 1. */
        double chainCost = INFINITY;
        if (m >= xbar[i + 1]) {
          chainCost =
              COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
        }

        /* Ties go to option 1; option 2 is recorded as back pointer -1. */
        if (bestLeafCost <= chainCost) {
          COST_TABLE(m, i, idx) = bestLeafCost;
          BACK_PTR(m, i, idx) = bestLeaf;
        } else {
          COST_TABLE(m, i, idx) = chainCost;
          BACK_PTR(m, i, idx) = -1;
        }
      }
    }
  }

  /* Convert the result into Python world */
  pyCostTable = PyList_New(mmax + 1);
  if (!pyCostTable) goto done;
  pyBackPtr = PyList_New(mmax + 1);
  if (!pyBackPtr) goto done;

  for (long m = 0; m <= mmax; ++m) {
    /* PyList_SET_ITEM steals the reference; the outer list owns the row. */
    PyObject* pyCostTable_m = PyList_New(chainLength + 1);
    if (!pyCostTable_m) goto done;
    PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);

    PyObject* pyBackPtr_m = PyList_New(chainLength + 1);
    if (!pyBackPtr_m) goto done;
    PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);

    for (long i = 0; i <= chainLength; ++i) {
      PyObject* pyCostTable_m_i = PyDict_New();
      if (!pyCostTable_m_i) goto done;
      PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);

      PyObject* pyBackPtr_m_i = PyDict_New();
      if (!pyBackPtr_m_i) goto done;
      PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);

      for (long l = i; l <= chainLength; ++l) {
        if (rotorStoreEntry(pyCostTable_m_i, pyBackPtr_m_i, l,
                            COST_TABLE(m, i, l), BACK_PTR(m, i, l)) < 0) {
          goto done;
        }
      }
    }
  }

  result = PyTuple_Pack(2, pyCostTable, pyBackPtr);

done:
  /* PyTuple_Pack took its own references, so ours are dropped either way. */
  Py_XDECREF(pyCostTable);
  Py_XDECREF(pyBackPtr);
  free(costTable);
  free(backPtr);
  free(ftime);
  free(btime);
  free(x);
  free(xbar);
  free(ftmp);
  free(btmp);

  /* Never hand NULL back to the interpreter without an exception. */
  if (!result && !PyErr_Occurred()) {
    PyErr_SetString(PyExc_RuntimeError,
                    "compute_table failed to read the chain parameters");
  }
  return result;
}
|
|
|
|
static PyMethodDef rotorMethods[] = {
|
|
{"compute_table", computeTable, METH_VARARGS,
|
|
"Compute the optimal table with the rotor algorithm."},
|
|
{NULL, NULL, 0, NULL} /* Sentinel */
|
|
};
|
|
|
|
static struct PyModuleDef rotorModule = {
|
|
PyModuleDef_HEAD_INIT, "rotorc", /* name of module */
|
|
"A simple implementation of dynamic programming algorithm rotor with C in "
|
|
"https://hal.inria.fr/hal-02352969. Some code are adapted from "
|
|
"https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be
|
|
NULL */
|
|
-1, /* size of per-interpreter state of the module,
|
|
or -1 if the module keeps state in global variables. */
|
|
rotorMethods};
|
|
|
|
PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); }
|