mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactor and add rotorc. (#1789)
* [autoparallel] refactor and add rotorc. * [autoparallel] refactor and add rotorc.pull/1783/head
parent
4d6e1284cb
commit
e8a9bebc87
|
@ -0,0 +1,16 @@
|
|||
"""Build script for the ``rotorc`` C extension.

Compiles ``ckpt_solver_rotor.c``, looked up in the same directory as this
script, into an extension module named ``rotorc``.
"""
import os

from setuptools import Extension, setup

# Anchor the source path on this script's own location so the build does not
# depend on the current working directory.
this_dir = os.path.dirname(os.path.abspath(__file__))
_rotor_source = os.path.join(this_dir, 'ckpt_solver_rotor.c')

ext_modules = [Extension('rotorc', sources=[_rotor_source])]

setup(name='rotor c extension',
      version='0.1',
      description='rotor c extension for faster dp computing',
      ext_modules=ext_modules)
|
|
@ -0,0 +1,197 @@
|
|||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
/*
 * Copy a Python sequence of integers into a newly allocated C array.
 *
 * The buffer holds len + 1 elements, where len is the sequence length; the
 * extra trailing element is zero (the buffer comes from calloc). The caller
 * owns the buffer and must release it with free().
 *
 * Returns NULL with a Python exception set when `pylist` is NULL, is not a
 * sequence, an item cannot be fetched or converted with PyLong_AsLong, or
 * memory cannot be allocated.
 */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!pylist) {
    /* A NULL argument normally comes from a failed lookup that already set an
       exception; only raise one here if nothing is pending. */
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_TypeError, "expected a sequence, got NULL");
    return NULL;
  }
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence");
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception set by PySequence_Size */

  /* calloc zero-fills, so the trailing pad element is already 0. */
  long* result = (long*)calloc((size_t)len + 1, sizeof(long));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    long value = PyLong_AsLong(item);
    Py_DECREF(item);
    /* -1 is a legal value; it only signals an error when an exception is
       pending. */
    if (value == -1 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
    result[i] = value;
  }
  return result;
}
|
||||
|
||||
/*
 * Copy a Python sequence of numbers into a newly allocated C array of double.
 *
 * The buffer holds len + 1 elements, where len is the sequence length; the
 * extra trailing element is zero (the buffer comes from calloc). The caller
 * owns the buffer and must release it with free().
 *
 * Returns NULL with a Python exception set when `pylist` is NULL, is not a
 * sequence, an item cannot be fetched or converted with PyFloat_AsDouble, or
 * memory cannot be allocated.
 */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!pylist) {
    /* A NULL argument normally comes from a failed lookup that already set an
       exception; only raise one here if nothing is pending. */
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_TypeError, "expected a sequence, got NULL");
    return NULL;
  }
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence");
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception set by PySequence_Size */

  /* calloc zero-fills, so the trailing pad element is already 0. */
  double* result = (double*)calloc((size_t)len + 1, sizeof(double));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    double value = PyFloat_AsDouble(item);
    Py_DECREF(item);
    /* -1.0 is a legal value; it only signals an error when an exception is
       pending. */
    if (value == -1.0 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
    result[i] = value;
  }
  return result;
}
|
||||
|
||||
/*
 * Read attribute `attributeName` from `container` and convert it to a C array
 * of long via PySequenceToLongArray().
 *
 * Returns the array produced by PySequenceToLongArray() (to be released with
 * free()), or NULL when the attribute lookup or the conversion fails.
 */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* The lookup returns NULL with an exception set on failure; bail out before
     Py_DECREF, which must never be handed a NULL pointer. */
  if (!sequence) return NULL;

  long* result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
|
||||
|
||||
/*
 * Read attribute `attributeName` from `container` and convert it to a C array
 * of double via PySequenceToDoubleArray().
 *
 * Returns the array produced by PySequenceToDoubleArray() (to be released
 * with free()), or NULL when the attribute lookup or the conversion fails.
 */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* The lookup returns NULL with an exception set on failure; bail out before
     Py_DECREF, which must never be handed a NULL pointer. */
  if (!sequence) return NULL;

  double* result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
|
||||
|
||||
/*
 * Check that attribute `attributeName` of `container` is a sequence holding
 * at least `minLength` items. Returns 0 on success, -1 with a Python
 * exception set otherwise.
 */
static int requireAttrLength(PyObject* container, const char* attributeName,
                             Py_ssize_t minLength) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  if (!sequence) return -1;
  Py_ssize_t len = PySequence_Size(sequence);
  Py_DECREF(sequence);
  if (len < 0) return -1;
  if (len < minLength) {
    PyErr_Format(PyExc_ValueError,
                 "chain attribute '%s' has %zd items, expected at least %zd",
                 attributeName, len, minLength);
    return -1;
  }
  return 0;
}

/*
 * compute_table(chain, mmax) -> (cost_table, back_ptr)
 *
 * Fills the rotor dynamic-programming tables for `chain`, an object exposing
 * __len__ and the sequence attributes ftime, btime, x, xbar, ftmp and btmp,
 * for every memory budget m in 0..mmax.
 *
 * Both returned values are lists indexed by m, each holding a list indexed by
 * i (0..len(chain)), each holding a dict keyed by l (i..len(chain)):
 *   cost_table[m][i][l] is a float, INFINITY when the budget is insufficient;
 *   back_ptr[m][i][l] is (True,) when the stored pointer is negative (the
 *   chain-checkpoint choice) and (False, j) otherwise. Cells for which no
 *   choice was recorded (i == l, or an infeasible budget) keep the zero the
 *   table was allocated with and are therefore returned as (False, 0).
 *
 * Raises ValueError for a negative mmax, an empty chain, attributes that are
 * too short for the indices read below, or negative entries in x / xbar
 * (these are used as table offsets); MemoryError when the tables cannot be
 * allocated; and propagates any error raised while reading the chain.
 */
static PyObject* computeTable(PyObject* self, PyObject* args) {
  PyObject* chainParam;
  int mmax;

  /* Everything released at `cleanup` starts out NULL so that every failure
     path can share a single exit. */
  double* ftime = NULL;
  double* btime = NULL;
  long* x = NULL;
  long* xbar = NULL;
  long* ftmp = NULL;
  long* btmp = NULL;
  double* costTable = NULL;
  long* backPtr = NULL;
  PyObject* pyCostTable = NULL;
  PyObject* pyBackPtr = NULL;
  PyObject* result = NULL;
  Py_ssize_t pyLength;
  long chainLength;
  size_t stride;
  size_t tableSize;

  if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL;
  if (mmax < 0) {
    PyErr_SetString(PyExc_ValueError, "mmax must be non-negative");
    return NULL;
  }

  pyLength = PyObject_Length(chainParam);
  if (pyLength < 0) return NULL; /* exception set by PyObject_Length */
  if (pyLength == 0) {
    PyErr_SetString(PyExc_ValueError, "chain must not be empty");
    return NULL;
  }
  if (pyLength > LONG_MAX - 2) {
    PyErr_SetString(PyExc_OverflowError, "chain is too long");
    return NULL;
  }
  chainLength = (long)pyLength;

  /* Each table has (mmax + 1) * stride * stride cells; refuse sizes that do
     not fit in size_t instead of allocating a truncated buffer. */
  stride = (size_t)chainLength + 1;
  if (stride > (size_t)-1 / stride ||
      stride * stride > (size_t)-1 / ((size_t)mmax + 1)) {
    PyErr_NoMemory();
    return NULL;
  }
  tableSize = ((size_t)mmax + 1) * stride * stride;

  /* The loops below read ftime/btime/ftmp/btmp up to index chainLength and
     x/xbar up to index chainLength + 1. The conversion helpers append one
     zero element, so these are the shortest inputs that stay in bounds. */
  if (requireAttrLength(chainParam, "ftime", pyLength) < 0 ||
      requireAttrLength(chainParam, "btime", pyLength) < 0 ||
      requireAttrLength(chainParam, "x", pyLength + 1) < 0 ||
      requireAttrLength(chainParam, "xbar", pyLength + 1) < 0 ||
      requireAttrLength(chainParam, "ftmp", pyLength) < 0 ||
      requireAttrLength(chainParam, "btmp", pyLength) < 0)
    return NULL;

  ftime = getDoubleArray(chainParam, "ftime");
  if (!ftime) goto cleanup;

  btime = getDoubleArray(chainParam, "btime");
  if (!btime) goto cleanup;

  x = getLongArray(chainParam, "x");
  if (!x) goto cleanup;

  xbar = getLongArray(chainParam, "xbar");
  if (!xbar) goto cleanup;

  /* Forward temporaries come from "ftmp"; reading "btmp" here would make the
     forward and backward memory checks identical. */
  ftmp = getLongArray(chainParam, "ftmp");
  if (!ftmp) goto cleanup;

  btmp = getLongArray(chainParam, "btmp");
  if (!btmp) goto cleanup;

  /* x[j] and xbar[i + 1] are subtracted from m to index the tables, so a
     negative entry would address a budget above mmax. */
  for (long i = 0; i <= chainLength; ++i) {
    if (x[i] < 0 || xbar[i] < 0) {
      PyErr_SetString(PyExc_ValueError,
                      "chain attributes 'x' and 'xbar' must be non-negative");
      goto cleanup;
    }
  }

#define TABLE_INDEX(m, i, l) \
  (((size_t)(m) * stride + (size_t)(i)) * stride + (size_t)(l))
#define COST_TABLE(m, i, l) costTable[TABLE_INDEX(m, i, l)]
#define BACK_PTR(m, i, l) backPtr[TABLE_INDEX(m, i, l)]

  costTable = (double*)calloc(tableSize, sizeof(double));
  backPtr = (long*)calloc(tableSize, sizeof(long));
  if (!costTable || !backPtr) {
    PyErr_NoMemory();
    goto cleanup;
  }

  /* Borders of the table: intervals made of a single stage (i == l). */
  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i <= chainLength; ++i) {
      long base = x[i + 1] + xbar[i + 1];
      if ((m >= base + btmp[i]) && (m >= base + ftmp[i]))
        COST_TABLE(m, i, i) = ftime[i] + btime[i];
      else
        COST_TABLE(m, i, i) = INFINITY;
    }

  /* Remaining cells, by increasing interval length d = idx - i. */
  for (long m = 0; m <= mmax; ++m)
    for (long d = 1; d <= chainLength; ++d) {
      for (long i = 0; i <= chainLength - d; ++i) {
        long idx = i + d;

        /* Smallest budget that lets the forward pass cross the interval. */
        long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            long costFWD = x[j] + x[j + 1] + ftmp[j];
            if (costFWD > maxCostFWD) maxCostFWD = costFWD;
          }
          if (x[idx + 1] + maxCostFWD > mmin) mmin = x[idx + 1] + maxCostFWD;
        }

        if (m < mmin) {
          COST_TABLE(m, i, idx) = INFINITY;
          continue;
        }

        /* Option 1: checkpoint the input of stage j, solve [j, idx] with the
           reduced budget, then come back for [i, j - 1]. */
        long bestLeaf = -1;
        double sumFw = 0;
        double bestLeafCost = INFINITY;
        for (long j = i + 1; j <= idx; ++j) {
          sumFw += ftime[j - 1];
          if (m >= x[j]) {
            double cost = sumFw + COST_TABLE(m - x[j], j, idx) +
                          COST_TABLE(m, i, j - 1);
            if (cost < bestLeafCost) {
              bestLeafCost = cost;
              bestLeaf = j;
            }
          }
        }

        /* Option 2: keep everything stage i produces (xbar) and recurse on
           the rest of the interval. */
        double chainCost = INFINITY;
        if (m >= xbar[i + 1])
          chainCost =
              COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);

        if (bestLeafCost <= chainCost) {
          COST_TABLE(m, i, idx) = bestLeafCost;
          BACK_PTR(m, i, idx) = bestLeaf;
        } else {
          COST_TABLE(m, i, idx) = chainCost;
          BACK_PTR(m, i, idx) = -1;
        }
      }
    }

  /* Convert the result into Python world. */
  pyCostTable = PyList_New((Py_ssize_t)mmax + 1);
  pyBackPtr = PyList_New((Py_ssize_t)mmax + 1);
  if (!pyCostTable || !pyBackPtr) goto cleanup;

  for (long m = 0; m <= mmax; ++m) {
    /* PyList_SET_ITEM steals the reference, so once stored each child is
       released together with its parent list. */
    PyObject* pyCostTable_m = PyList_New(chainLength + 1);
    if (!pyCostTable_m) goto cleanup;
    PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);

    PyObject* pyBackPtr_m = PyList_New(chainLength + 1);
    if (!pyBackPtr_m) goto cleanup;
    PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);

    for (long i = 0; i <= chainLength; ++i) {
      PyObject* pyCostTable_m_i = PyDict_New();
      if (!pyCostTable_m_i) goto cleanup;
      PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);

      PyObject* pyBackPtr_m_i = PyDict_New();
      if (!pyBackPtr_m_i) goto cleanup;
      PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);

      for (long l = i; l <= chainLength; ++l) {
        PyObject* pyVar_l = PyLong_FromLong(l);
        PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l));
        PyObject* pyBackPtr_m_i_l;
        if (BACK_PTR(m, i, l) < 0)
          pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
        else
          pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));

        int failed =
            !pyVar_l || !pyCostTable_m_i_l || !pyBackPtr_m_i_l ||
            PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l) < 0 ||
            PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l) < 0;

        /* PyDict_SetItem does not steal references; drop ours either way. */
        Py_XDECREF(pyVar_l);
        Py_XDECREF(pyCostTable_m_i_l);
        Py_XDECREF(pyBackPtr_m_i_l);
        if (failed) goto cleanup;
      }
    }
  }

  result = PyTuple_Pack(2, pyCostTable, pyBackPtr);

cleanup:
  /* free(NULL) and Py_XDECREF(NULL) are no-ops, so this is safe from every
     jump above. PyTuple_Pack took its own references to the two lists. */
  free(ftime);
  free(btime);
  free(x);
  free(xbar);
  free(ftmp);
  free(btmp);
  free(costTable);
  free(backPtr);
  Py_XDECREF(pyCostTable);
  Py_XDECREF(pyBackPtr);
  return result;
}
|
||||
|
||||
/* Method table of the module: a single function, exported to Python as
   `compute_table` and called with positional arguments only. */
static PyMethodDef rotorMethods[] = {
    {
        .ml_name = "compute_table",
        .ml_meth = computeTable,
        .ml_flags = METH_VARARGS,
        .ml_doc = "Compute the optimal table with the rotor algorithm.",
    },
    /* All-zero entry terminating the table. */
    {.ml_name = NULL, .ml_meth = NULL, .ml_flags = 0, .ml_doc = NULL},
};
|
||||
|
||||
/* Module definition handed to PyModule_Create() by PyInit_rotorc(). */
static struct PyModuleDef rotorModule = {
    .m_base = PyModuleDef_HEAD_INIT,
    /* Name of the module. */
    .m_name = "rotorc",
    /* Module documentation. */
    .m_doc =
        "A simple implementation of dynamic programming algorithm rotor with C in "
        "https://hal.inria.fr/hal-02352969. Some code are adapted from "
        "https://gitlab.inria.fr/hiepacs/rotor.",
    /* Size of the per-interpreter state; -1 means the module keeps its state
       in global variables. */
    .m_size = -1,
    .m_methods = rotorMethods,
};
|
||||
|
||||
/* Module initialisation entry point: builds the module object from
   `rotorModule` and returns it (NULL if creation failed). */
PyMODINIT_FUNC PyInit_rotorc(void) {
  PyObject* module = PyModule_Create(&rotorModule);
  return module;
}
|
|
@ -1,5 +1,5 @@
|
|||
from copy import deepcopy
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
from torch.fx import Graph, Node
|
||||
|
@ -15,9 +15,9 @@ from colossalai.fx.profiler import (
|
|||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
|
||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
|
||||
|
||||
__all__ = ['CheckpointSolverBase']
|
||||
__all__ = ['CheckpointSolverRotor']
|
||||
|
||||
|
||||
class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
|
@ -59,11 +59,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
self.back_ptr = None
|
||||
self.sequence = None
|
||||
|
||||
def solve(self, force_python: bool = False) -> Graph:
|
||||
def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
|
||||
"""Solve the checkpointing problem using rotor algorithm.
|
||||
|
||||
Args:
|
||||
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
|
||||
verbose (bool, optional): Print verbose information. Defaults to False.
|
||||
|
||||
Returns:
|
||||
graph (Graph): The optimized graph, should be a copy of the original graph.
|
||||
|
@ -76,14 +77,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
else:
|
||||
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
|
||||
|
||||
if verbose:
|
||||
self.print_chain()
|
||||
|
||||
# backtrack
|
||||
try:
|
||||
self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr)
|
||||
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
|
||||
self.back_ptr)
|
||||
self._annotate_from_sequence(self.sequence, self.node_list)
|
||||
except RuntimeError as e:
|
||||
except ValueError as e:
|
||||
# use the logger to announce that the solver has failed
|
||||
logger = get_dist_logger()
|
||||
logger.warning(f'Checkpoint solver failed: {e}')
|
||||
raise ValueError
|
||||
|
||||
if verbose:
|
||||
self.print_sequence()
|
||||
|
||||
return deepcopy(self.graph)
|
||||
|
||||
|
@ -100,42 +109,42 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
@classmethod
|
||||
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
|
||||
input_tensors = cls._extract_input(graph)
|
||||
fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list()
|
||||
ftime, btime, ftmp, btmp = list(), list(), list(), list()
|
||||
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
for node in node_list:
|
||||
node_info = cls._extract_node_info(node)
|
||||
fwd_time.append(node_info[0])
|
||||
bwd_time.append(node_info[1])
|
||||
ftime.append(node_info[0])
|
||||
btime.append(node_info[1])
|
||||
x.append(node_info[2])
|
||||
xbar.append(node_info[3])
|
||||
ftmp.append(node_info[4])
|
||||
btmp.append(node_info[5])
|
||||
|
||||
# currently we view loss backward temp as zero
|
||||
bwd_time.append(0)
|
||||
btime.append(0)
|
||||
btmp.append(0)
|
||||
|
||||
return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp)
|
||||
return Chain(ftime, btime, x, xbar, ftmp, btmp)
|
||||
|
||||
@classmethod
|
||||
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
|
||||
"""Extract node info from a list of nodes"""
|
||||
xbar = 0
|
||||
fwd_time = 0
|
||||
bwd_time = 0
|
||||
ftime = 0
|
||||
btime = 0
|
||||
for n in node:
|
||||
assert isinstance(n, Node), f'{n} is not a Node'
|
||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
||||
# minimum flop count is required
|
||||
fwd_time += max(calculate_fwd_time(n), 1.0)
|
||||
bwd_time += max(calculate_bwd_time(n), 1.0)
|
||||
ftime += max(calculate_fwd_time(n), 1.0)
|
||||
btime += max(calculate_bwd_time(n), 1.0)
|
||||
|
||||
x = calculate_fwd_out(node[-1])
|
||||
xbar = max(x, xbar)
|
||||
ftmp = cls._extract_ftmp(node)
|
||||
btmp = cls._extract_btmp(node)
|
||||
return fwd_time, bwd_time, x, xbar, ftmp, btmp
|
||||
return ftime, btime, x, xbar, ftmp, btmp
|
||||
|
||||
@staticmethod
|
||||
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
|
||||
|
@ -180,17 +189,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
return btmp
|
||||
|
||||
@staticmethod
|
||||
def _compute_table(chain: Chain, mem_slots: int) -> Tuple:
|
||||
def _compute_table(chain: Chain, mmax: int) -> Tuple:
|
||||
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
|
||||
|
||||
Args:
|
||||
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
|
||||
mem_slots (int): Number of slots for discretizing memory budget.
|
||||
mmax (int): Maximum number of memory slots.
|
||||
|
||||
Returns:
|
||||
cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length
|
||||
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
|
||||
back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice
|
||||
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
|
||||
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
|
||||
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
|
||||
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
|
||||
of length j
|
||||
"""
|
||||
|
@ -203,13 +212,13 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
btmp = chain.btmp + [0]
|
||||
|
||||
# Build table
|
||||
cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
|
||||
back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
|
||||
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
|
||||
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
|
||||
|
||||
# Initialize borders of the tables for lmax-lmin = 0
|
||||
for m in range(mem_slots + 1):
|
||||
for i in range(chain.length + 1):
|
||||
for m in range(mmax + 1):
|
||||
for i in range(len(chain) + 1):
|
||||
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
|
||||
if m >= limit: # Equation (1)
|
||||
cost_table[m][i][i] = ftime[i] + btime[i]
|
||||
|
@ -217,9 +226,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
cost_table[m][i][i] = float("inf")
|
||||
|
||||
# Compute everything
|
||||
for m in range(mem_slots + 1):
|
||||
for d in range(1, chain.length + 1):
|
||||
for i in range(chain.length + 1 - d):
|
||||
for m in range(mmax + 1):
|
||||
for d in range(1, len(chain) + 1):
|
||||
for i in range(len(chain) + 1 - d):
|
||||
idx = i + d
|
||||
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
|
||||
if idx > i + 1:
|
||||
|
@ -248,20 +257,46 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
return cost_table, back_ptr
|
||||
|
||||
@staticmethod
|
||||
def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple:
|
||||
raise NotImplementedError("C implementation not available yet")
|
||||
def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
|
||||
try:
|
||||
from .rotorc import compute_table
|
||||
|
||||
def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]],
|
||||
back_ptr: List[List[Dict[int, int]]]) -> List[int]:
|
||||
# build module if module not found
|
||||
except ModuleNotFoundError:
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
logger = get_dist_logger()
|
||||
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
result = subprocess.Popen(
|
||||
[
|
||||
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
|
||||
f"--build-lib={this_dir}"
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
if result.wait() == 0:
|
||||
logger.info("rotorc has been built!", ranks=[0])
|
||||
from .rotorc import compute_table
|
||||
else:
|
||||
logger.warning("rotorc built failed! Using python version!", ranks=[0])
|
||||
return CheckpointSolverRotor._compute_table(chain, mmax)
|
||||
return compute_table(chain, mmax)
|
||||
|
||||
@staticmethod
|
||||
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
|
||||
back_ptr: List[Any]) -> "Sequence":
|
||||
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
|
||||
|
||||
Args:
|
||||
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
|
||||
lmin (int): The left index of the interval to backtrack.
|
||||
lmax (int): The right index of the interval to backtrack.
|
||||
mem_budget (int): The memory budget for processing this interval.
|
||||
cost_table (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
|
||||
back_ptr (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
|
||||
lhs (int): The left index of the interval to backtrack.
|
||||
rhs (int): The right index of the interval to backtrack.
|
||||
budget (int): The memory budget for processing this interval.
|
||||
cost_table (List[Any]): See `._compute_table()` for definitions
|
||||
back_ptr (List[Any]): See `._compute_table()` for definitions
|
||||
|
||||
Raises:
|
||||
ValueError: Can not process the chain.
|
||||
|
@ -269,36 +304,45 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
Returns:
|
||||
sequence (Sequence): The sequence of executing nodes with checkpoints.
|
||||
"""
|
||||
if mem_budget <= 0:
|
||||
raise ValueError(f"Can not process a chain with negative memory {mem_budget}")
|
||||
elif cost_table[mem_budget][lmin][lmax] == float("inf"):
|
||||
raise ValueError(f"Can not process this chain from index {lmin} to {lmax} with memory {mem_budget}")
|
||||
if budget <= 0:
|
||||
raise ValueError(f"Can not process a chain with negative memory {budget}")
|
||||
elif cost_table[budget][lhs][rhs] == float("inf"):
|
||||
raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")
|
||||
|
||||
sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget))
|
||||
if lmin == lmax:
|
||||
if lmin == chain.length:
|
||||
sequence.insert(Loss())
|
||||
sequence = Sequence()
|
||||
if rhs == lhs:
|
||||
if lhs == len(chain):
|
||||
sequence += [Loss()]
|
||||
else:
|
||||
sequence.insert(ForwardEnable(lmin))
|
||||
sequence.insert(Backward(lmin))
|
||||
sequence += [ForwardEnable(lhs), Backward(lhs)]
|
||||
return sequence
|
||||
|
||||
if back_ptr[mem_budget][lmin][lmax][0]:
|
||||
sequence.insert(ForwardEnable(lmin))
|
||||
sequence.insert_sequence(
|
||||
self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr))
|
||||
sequence.insert(Backward(lmin))
|
||||
if back_ptr[budget][lhs][rhs][0]:
|
||||
sequence += [
|
||||
ForwardEnable(lhs),
|
||||
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
|
||||
back_ptr),
|
||||
Backward(lhs),
|
||||
]
|
||||
else:
|
||||
j = back_ptr[mem_budget][lmin][lmax][1]
|
||||
sequence.insert(ForwardCheck(lmin))
|
||||
for k in range(lmin + 1, j):
|
||||
sequence.insert(ForwardNograd(k))
|
||||
sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr))
|
||||
sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr))
|
||||
best_leaf = back_ptr[budget][lhs][rhs][1]
|
||||
sequence += [ForwardCheck(lhs)]
|
||||
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
|
||||
sequence += [
|
||||
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
|
||||
back_ptr),
|
||||
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
|
||||
]
|
||||
return sequence
|
||||
|
||||
@staticmethod
|
||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
|
||||
|
||||
Args:
|
||||
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
|
||||
node_list (List[List[Node]]): The list of nodes to annotate.
|
||||
"""
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
from typing import Any, Iterable, List
|
||||
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
@ -33,23 +33,25 @@ class Chain:
|
|||
self.xbar = xbar
|
||||
self.ftmp = ftmp
|
||||
self.btmp = btmp
|
||||
self.length = len(ftime)
|
||||
if check_consistency and not self.check_lengths():
|
||||
raise AttributeError("In Chain, input lists do not have consistent lengths")
|
||||
|
||||
def check_lengths(self):
|
||||
return ((len(self.ftime) == self.length) and (len(self.btime) == self.length + 1)
|
||||
and (len(self.x) == self.length + 1) and (len(self.ftmp) == self.length)
|
||||
and (len(self.btmp) == self.length + 1) and (len(self.xbar) == self.length + 1))
|
||||
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
|
||||
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
|
||||
and (len(self.xbar) == len(self) + 1))
|
||||
|
||||
def __repr__(self):
|
||||
chain_list = []
|
||||
for i in range(self.length):
|
||||
for i in range(len(self)):
|
||||
chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
|
||||
i = self.length
|
||||
i = len(self)
|
||||
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
|
||||
return chain_list.__repr__()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ftime)
|
||||
|
||||
def discretize_all(self, unit: int):
|
||||
"""Discretize the chain into a list of chains according to unit size."""
|
||||
discretizer = lambda val: math.ceil(val / unit)
|
||||
|
@ -163,79 +165,20 @@ class DiscardMemory(MemoryAccess):
|
|||
name = "DM"
|
||||
|
||||
|
||||
class Function:
|
||||
class Sequence(list):
|
||||
|
||||
def __init__(self, name, *args):
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.str_args = ','.join(str(v) for v in self.args)
|
||||
|
||||
def __repr__(self):
|
||||
return "{n}({args})".format(n=self.name, args=self.str_args)
|
||||
|
||||
|
||||
class Sequence:
|
||||
|
||||
def __init__(self, function):
|
||||
self.sequence = [] #List of Operation and Sequence
|
||||
self.function = function #Description the function (name and parameters)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.list_operations())
|
||||
|
||||
def list_operations(self):
|
||||
op_list = []
|
||||
for x in self.sequence:
|
||||
for x in self:
|
||||
if isinstance(x, Operation):
|
||||
op_list.append(x)
|
||||
else:
|
||||
assert isinstance(x, Sequence)
|
||||
op_list += x.list_operations()
|
||||
return op_list
|
||||
|
||||
def insert(self, operation):
|
||||
self.sequence.append(operation)
|
||||
|
||||
def remove(self, operation_index):
|
||||
del self.sequence[operation_index]
|
||||
|
||||
def insert_sequence(self, sequence):
|
||||
self.sequence.append(sequence)
|
||||
|
||||
def shift(self, value):
|
||||
for x in self.sequence:
|
||||
x.shift(value)
|
||||
return self
|
||||
|
||||
def remove_useless_write(self):
|
||||
if self.sequence:
|
||||
if isinstance(self.sequence[0], WriteMemory):
|
||||
self.remove(0)
|
||||
return self
|
||||
|
||||
def get_makespan(self, chain):
|
||||
return sum(op.cost(chain) for op in self.list_operations())
|
||||
|
||||
def without_suffix(self):
|
||||
ops = self.list_operations()
|
||||
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
|
||||
try:
|
||||
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
|
||||
except ValueError:
|
||||
last_idx = -1
|
||||
if last_idx == end_of_first_phase - 1:
|
||||
return (self, None)
|
||||
chain_length = ops[end_of_first_phase -
|
||||
1].index ## Some assumption here about the sequence (finishes with Forward_L
|
||||
start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice
|
||||
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
|
||||
for i in range(last_idx + 1):
|
||||
result.insert(ops[i])
|
||||
result.insert(Loss())
|
||||
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
|
||||
position = end_of_first_phase + 1 + (chain_length - i)
|
||||
assert type(ops[position]) is Backward
|
||||
assert ops[position].index == i
|
||||
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
|
||||
result.insert(ops[i])
|
||||
return (result, start_of_fwd_enable_chain)
|
||||
|
|
|
@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
|||
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||
if inplace:
|
||||
kwargs['inplace'] = True
|
||||
meta.bwd_mem_tmp = 0
|
||||
meta.bwd_mem_out = 0
|
||||
do_not_cache = False
|
||||
|
||||
meta.bwd_mem_out -= param_size
|
||||
|
@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
|
|||
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||
if inplace:
|
||||
module.inplace = True
|
||||
meta.bwd_mem_tmp = 0
|
||||
meta.bwd_mem_out = 0
|
||||
do_not_cache = False
|
||||
|
||||
# grad for param will not be counted
|
||||
|
|
Loading…
Reference in New Issue