[autoparallel] rotor solver refactor (#2813)

* [autoparallel] rotor solver refactor

* [autoparallel] rotor solver refactor
pull/2826/head
Boyuan Yao 2023-02-18 11:30:15 +08:00 committed by GitHub
parent 09f457479d
commit 8593ae1a3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 17 deletions

View File

@ -1,6 +1,12 @@
#define PY_SSIZE_T_CLEAN #define PY_SSIZE_T_CLEAN
#include <Python.h> #include <Python.h>
/*
Rotor solver for checkpointing problem in C. We follow the modeling mentioned in
paper `Optimal checkpointing for heterogeneous chains: how to train deep neural
networks with limited memory` https://hal.inria.fr/hal-02352969. Some lines of
the code are adapted from https://gitlab.inria.fr/hiepacs/rotor.
*/
long* PySequenceToLongArray(PyObject* pylist) { long* PySequenceToLongArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL; if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist); Py_ssize_t len = PySequence_Size(pylist);
@ -81,14 +87,16 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long)); (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
for (long m = 0; m <= mmax; ++m) for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chainLength; ++i) for (long i = 0; i <= chainLength; ++i) {
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) && if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
(m >= x[i + 1] + xbar[i + 1] + ftmp[i])) (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) {
COST_TABLE(m, i, i) = ftime[i] + btime[i]; COST_TABLE(m, i, i) = ftime[i] + btime[i];
else } else {
COST_TABLE(m, i, i) = INFINITY; COST_TABLE(m, i, i) = INFINITY;
}
}
for (long m = 0; m <= mmax; ++m) for (long m = 0; m <= mmax; ++m) {
for (long d = 1; d <= chainLength; ++d) { for (long d = 1; d <= chainLength; ++d) {
for (long i = 0; i <= chainLength - d; ++i) { for (long i = 0; i <= chainLength - d; ++i) {
long idx = i + d; long idx = i + d;
@ -116,9 +124,10 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
} }
} }
double chainCost = INFINITY; double chainCost = INFINITY;
if (m >= xbar[i + 1]) if (m >= xbar[i + 1]) {
chainCost = chainCost =
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx); COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
}
if (bestLeafCost <= chainCost) { if (bestLeafCost <= chainCost) {
COST_TABLE(m, i, idx) = bestLeafCost; COST_TABLE(m, i, idx) = bestLeafCost;
BACK_PTR(m, i, idx) = bestLeaf; BACK_PTR(m, i, idx) = bestLeaf;
@ -126,10 +135,12 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
COST_TABLE(m, i, idx) = chainCost; COST_TABLE(m, i, idx) = chainCost;
BACK_PTR(m, i, idx) = -1; BACK_PTR(m, i, idx) = -1;
} }
} else } else {
COST_TABLE(m, i, idx) = INFINITY; COST_TABLE(m, i, idx) = INFINITY;
} }
} }
}
}
free(ftime); free(ftime);
free(btime); free(btime);
@ -158,10 +169,11 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l); PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
Py_DECREF(pyCostTable_m_i_l); Py_DECREF(pyCostTable_m_i_l);
PyObject* pyBackPtr_m_i_l; PyObject* pyBackPtr_m_i_l;
if (BACK_PTR(m, i, l) < 0) if (BACK_PTR(m, i, l) < 0) {
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True); pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
else } else {
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l)); pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
}
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l); PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
Py_DECREF(pyBackPtr_m_i_l); Py_DECREF(pyBackPtr_m_i_l);
Py_DECREF(pyVar_l); Py_DECREF(pyVar_l);

View File

@ -207,11 +207,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
mmax (int): Maximum number of memory slots. mmax (int): Maximum number of memory slots.
Returns: Returns:
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax with m memory slots.
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j
of length j
""" """
ftime = chain.ftime + [0.0] ftime = chain.ftime + [0.0]
@ -224,18 +223,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
# Build table # Build table
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0 # Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs
for m in range(mmax + 1): for m in range(mmax + 1):
for i in range(len(chain) + 1): for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1) if m >= limit:
cost_table[m][i][i] = ftime[i] + btime[i] cost_table[m][i][i] = ftime[i] + btime[i]
else: else:
cost_table[m][i][i] = float("inf") cost_table[m][i][i] = float("inf")
# Compute everything # Compute tables
for m in range(mmax + 1): for m in range(mmax + 1):
for d in range(1, len(chain) + 1): for d in range(1, len(chain) + 1):
for i in range(len(chain) + 1 - d): for i in range(len(chain) + 1 - d):