mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] rotor solver refactor (#2813)
* [autoparallel] rotor solver refactor * [autoparallel] rotor solver refactorpull/2826/head
parent
09f457479d
commit
8593ae1a3f
|
@ -1,6 +1,12 @@
|
|||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
/*
|
||||
Rotor solver for checkpointing problem in C. We follow the modeling mentioned in
|
||||
paper `Optimal checkpointing for heterogeneous chains: how to train deep neural
|
||||
networks with limited memory` https://hal.inria.fr/hal-02352969. Some lines of
|
||||
the code are adapted from https://gitlab.inria.fr/hiepacs/rotor.
|
||||
*/
|
||||
long* PySequenceToLongArray(PyObject* pylist) {
|
||||
if (!(pylist && PySequence_Check(pylist))) return NULL;
|
||||
Py_ssize_t len = PySequence_Size(pylist);
|
||||
|
@ -81,14 +87,16 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
|||
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chainLength; ++i)
|
||||
for (long i = 0; i <= chainLength; ++i) {
|
||||
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
|
||||
(m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
|
||||
(m >= x[i + 1] + xbar[i + 1] + ftmp[i])) {
|
||||
COST_TABLE(m, i, i) = ftime[i] + btime[i];
|
||||
else
|
||||
} else {
|
||||
COST_TABLE(m, i, i) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long m = 0; m <= mmax; ++m) {
|
||||
for (long d = 1; d <= chainLength; ++d) {
|
||||
for (long i = 0; i <= chainLength - d; ++i) {
|
||||
long idx = i + d;
|
||||
|
@ -116,9 +124,10 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
|||
}
|
||||
}
|
||||
double chainCost = INFINITY;
|
||||
if (m >= xbar[i + 1])
|
||||
if (m >= xbar[i + 1]) {
|
||||
chainCost =
|
||||
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
|
||||
}
|
||||
if (bestLeafCost <= chainCost) {
|
||||
COST_TABLE(m, i, idx) = bestLeafCost;
|
||||
BACK_PTR(m, i, idx) = bestLeaf;
|
||||
|
@ -126,10 +135,12 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
|||
COST_TABLE(m, i, idx) = chainCost;
|
||||
BACK_PTR(m, i, idx) = -1;
|
||||
}
|
||||
} else
|
||||
} else {
|
||||
COST_TABLE(m, i, idx) = INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(ftime);
|
||||
free(btime);
|
||||
|
@ -158,10 +169,11 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
|
|||
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
|
||||
Py_DECREF(pyCostTable_m_i_l);
|
||||
PyObject* pyBackPtr_m_i_l;
|
||||
if (BACK_PTR(m, i, l) < 0)
|
||||
if (BACK_PTR(m, i, l) < 0) {
|
||||
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
|
||||
else
|
||||
} else {
|
||||
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
|
||||
}
|
||||
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
|
||||
Py_DECREF(pyBackPtr_m_i_l);
|
||||
Py_DECREF(pyVar_l);
|
||||
|
|
|
@ -207,11 +207,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
mmax (int): Maximum number of memory slots.
|
||||
|
||||
Returns:
|
||||
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
|
||||
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
|
||||
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
|
||||
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
|
||||
of length j
|
||||
cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs
|
||||
with m memory slots.
|
||||
back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice
|
||||
is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j
|
||||
"""
|
||||
|
||||
ftime = chain.ftime + [0.0]
|
||||
|
@ -224,18 +223,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
# Build table
|
||||
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
|
||||
|
||||
# Initialize borders of the tables for lmax-lmin = 0
|
||||
# Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs
|
||||
for m in range(mmax + 1):
|
||||
for i in range(len(chain) + 1):
|
||||
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
|
||||
if m >= limit: # Equation (1)
|
||||
if m >= limit:
|
||||
cost_table[m][i][i] = ftime[i] + btime[i]
|
||||
else:
|
||||
cost_table[m][i][i] = float("inf")
|
||||
|
||||
# Compute everything
|
||||
# Compute tables
|
||||
for m in range(mmax + 1):
|
||||
for d in range(1, len(chain) + 1):
|
||||
for i in range(len(chain) + 1 - d):
|
||||
|
|
Loading…
Reference in New Issue