mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] rotor solver refactor (#2813)
* [autoparallel] rotor solver refactor * [autoparallel] rotor solver refactorpull/2826/head
parent
09f457479d
commit
8593ae1a3f
|
@ -1,6 +1,12 @@
|
||||||
#define PY_SSIZE_T_CLEAN
|
#define PY_SSIZE_T_CLEAN
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
|
/*
|
||||||
|
Rotor solver for checkpointing problem in C. We follow the modeling mentioned in
|
||||||
|
paper `Optimal checkpointing for heterogeneous chains: how to train deep neural
|
||||||
|
networks with limited memory` https://hal.inria.fr/hal-02352969. Some lines of
|
||||||
|
the code are adapted from https://gitlab.inria.fr/hiepacs/rotor.
|
||||||
|
*/
|
||||||
long* PySequenceToLongArray(PyObject* pylist) {
|
long* PySequenceToLongArray(PyObject* pylist) {
|
||||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||||
Py_ssize_t len = PySequence_Size(pylist);
|
Py_ssize_t len = PySequence_Size(pylist);
|
||||||
|
@ -81,14 +87,16 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
||||||
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
|
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
|
||||||
|
|
||||||
for (long m = 0; m <= mmax; ++m)
|
for (long m = 0; m <= mmax; ++m)
|
||||||
for (long i = 0; i <= chainLength; ++i)
|
for (long i = 0; i <= chainLength; ++i) {
|
||||||
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
|
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
|
||||||
(m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
|
(m >= x[i + 1] + xbar[i + 1] + ftmp[i])) {
|
||||||
COST_TABLE(m, i, i) = ftime[i] + btime[i];
|
COST_TABLE(m, i, i) = ftime[i] + btime[i];
|
||||||
else
|
} else {
|
||||||
COST_TABLE(m, i, i) = INFINITY;
|
COST_TABLE(m, i, i) = INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (long m = 0; m <= mmax; ++m)
|
for (long m = 0; m <= mmax; ++m) {
|
||||||
for (long d = 1; d <= chainLength; ++d) {
|
for (long d = 1; d <= chainLength; ++d) {
|
||||||
for (long i = 0; i <= chainLength - d; ++i) {
|
for (long i = 0; i <= chainLength - d; ++i) {
|
||||||
long idx = i + d;
|
long idx = i + d;
|
||||||
|
@ -116,9 +124,10 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
double chainCost = INFINITY;
|
double chainCost = INFINITY;
|
||||||
if (m >= xbar[i + 1])
|
if (m >= xbar[i + 1]) {
|
||||||
chainCost =
|
chainCost =
|
||||||
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
|
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
|
||||||
|
}
|
||||||
if (bestLeafCost <= chainCost) {
|
if (bestLeafCost <= chainCost) {
|
||||||
COST_TABLE(m, i, idx) = bestLeafCost;
|
COST_TABLE(m, i, idx) = bestLeafCost;
|
||||||
BACK_PTR(m, i, idx) = bestLeaf;
|
BACK_PTR(m, i, idx) = bestLeaf;
|
||||||
|
@ -126,10 +135,12 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
||||||
COST_TABLE(m, i, idx) = chainCost;
|
COST_TABLE(m, i, idx) = chainCost;
|
||||||
BACK_PTR(m, i, idx) = -1;
|
BACK_PTR(m, i, idx) = -1;
|
||||||
}
|
}
|
||||||
} else
|
} else {
|
||||||
COST_TABLE(m, i, idx) = INFINITY;
|
COST_TABLE(m, i, idx) = INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
free(ftime);
|
free(ftime);
|
||||||
free(btime);
|
free(btime);
|
||||||
|
@ -158,10 +169,11 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
||||||
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
|
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
|
||||||
Py_DECREF(pyCostTable_m_i_l);
|
Py_DECREF(pyCostTable_m_i_l);
|
||||||
PyObject* pyBackPtr_m_i_l;
|
PyObject* pyBackPtr_m_i_l;
|
||||||
if (BACK_PTR(m, i, l) < 0)
|
if (BACK_PTR(m, i, l) < 0) {
|
||||||
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
|
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
|
||||||
else
|
} else {
|
||||||
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
|
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
|
||||||
|
}
|
||||||
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
|
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
|
||||||
Py_DECREF(pyBackPtr_m_i_l);
|
Py_DECREF(pyBackPtr_m_i_l);
|
||||||
Py_DECREF(pyVar_l);
|
Py_DECREF(pyVar_l);
|
||||||
|
|
|
@ -207,11 +207,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||||
mmax (int): Maximum number of memory slots.
|
mmax (int): Maximum number of memory slots.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
|
cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs
|
||||||
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
|
with m memory slots.
|
||||||
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
|
back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice
|
||||||
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
|
is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j
|
||||||
of length j
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ftime = chain.ftime + [0.0]
|
ftime = chain.ftime + [0.0]
|
||||||
|
@ -224,18 +223,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||||
# Build table
|
# Build table
|
||||||
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||||
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||||
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
|
|
||||||
|
|
||||||
# Initialize borders of the tables for lmax-lmin = 0
|
# Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs
|
||||||
for m in range(mmax + 1):
|
for m in range(mmax + 1):
|
||||||
for i in range(len(chain) + 1):
|
for i in range(len(chain) + 1):
|
||||||
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
|
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
|
||||||
if m >= limit: # Equation (1)
|
if m >= limit:
|
||||||
cost_table[m][i][i] = ftime[i] + btime[i]
|
cost_table[m][i][i] = ftime[i] + btime[i]
|
||||||
else:
|
else:
|
||||||
cost_table[m][i][i] = float("inf")
|
cost_table[m][i][i] = float("inf")
|
||||||
|
|
||||||
# Compute everything
|
# Compute tables
|
||||||
for m in range(mmax + 1):
|
for m in range(mmax + 1):
|
||||||
for d in range(1, len(chain) + 1):
|
for d in range(1, len(chain) + 1):
|
||||||
for i in range(len(chain) + 1 - d):
|
for i in range(len(chain) + 1 - d):
|
||||||
|
|
Loading…
Reference in New Issue