[autoparallel] rotor solver refactor (#2813)

* [autoparallel] rotor solver refactor

* [autoparallel] rotor solver refactor
pull/2826/head
Boyuan Yao 2023-02-18 11:30:15 +08:00 committed by GitHub
parent 09f457479d
commit 8593ae1a3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 17 deletions

View File

@ -1,6 +1,12 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
/*
Rotor solver for the checkpointing problem, implemented in C. We follow the
model described in the paper `Optimal checkpointing for heterogeneous chains:
how to train deep neural networks with limited memory`
(https://hal.inria.fr/hal-02352969). Some lines of the code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
*/
long* PySequenceToLongArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
@ -81,14 +87,16 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chainLength; ++i)
for (long i = 0; i <= chainLength; ++i) {
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
(m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
(m >= x[i + 1] + xbar[i + 1] + ftmp[i])) {
COST_TABLE(m, i, i) = ftime[i] + btime[i];
else
} else {
COST_TABLE(m, i, i) = INFINITY;
}
}
for (long m = 0; m <= mmax; ++m)
for (long m = 0; m <= mmax; ++m) {
for (long d = 1; d <= chainLength; ++d) {
for (long i = 0; i <= chainLength - d; ++i) {
long idx = i + d;
@ -116,9 +124,10 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
}
}
double chainCost = INFINITY;
if (m >= xbar[i + 1])
if (m >= xbar[i + 1]) {
chainCost =
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
}
if (bestLeafCost <= chainCost) {
COST_TABLE(m, i, idx) = bestLeafCost;
BACK_PTR(m, i, idx) = bestLeaf;
@ -126,10 +135,12 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
COST_TABLE(m, i, idx) = chainCost;
BACK_PTR(m, i, idx) = -1;
}
} else
} else {
COST_TABLE(m, i, idx) = INFINITY;
}
}
}
}
free(ftime);
free(btime);
@ -158,10 +169,11 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
Py_DECREF(pyCostTable_m_i_l);
PyObject* pyBackPtr_m_i_l;
if (BACK_PTR(m, i, l) < 0)
if (BACK_PTR(m, i, l) < 0) {
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
else
} else {
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
}
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
Py_DECREF(pyBackPtr_m_i_l);
Py_DECREF(pyVar_l);

View File

@ -207,11 +207,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
mmax (int): Maximum number of memory slots.
Returns:
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
of length j
cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs
with m memory slots.
back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice
is a chain checkpoint, or (False, j) if the optimal choice is a leaf checkpoint of length j.
"""
ftime = chain.ftime + [0.0]
@ -224,18 +223,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
# Build table
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
# Initialize the corner cases where the sequence length equals 1, i.e. lhs == rhs
for m in range(mmax + 1):
for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1)
if m >= limit:
cost_table[m][i][i] = ftime[i] + btime[i]
else:
cost_table[m][i][i] = float("inf")
# Compute everything
# Compute tables
for m in range(mmax + 1):
for d in range(1, len(chain) + 1):
for i in range(len(chain) + 1 - d):