[autoparallel] refactor and add rotorc. (#1789)

* [autoparallel] refactor and add rotorc.

* [autoparallel] refactor and add rotorc.
pull/1783/head
Super Daniel 2022-11-03 12:32:51 +08:00 committed by GitHub
parent 4d6e1284cb
commit e8a9bebc87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 333 additions and 129 deletions

View File

@ -0,0 +1,16 @@
import os
from setuptools import Extension, setup
# Build script for the `rotorc` C extension.
# The C source is located relative to this file, so the script does not
# depend on the current working directory.
this_dir = os.path.dirname(os.path.abspath(__file__))
rotor_source = os.path.join(this_dir, 'ckpt_solver_rotor.c')

# A single extension module named `rotorc`, compiled from one C file.
rotor_extension = Extension('rotorc', sources=[rotor_source])
ext_modules = [rotor_extension]

setup(
    name='rotor c extension',
    version='0.1',
    description='rotor c extension for faster dp computing',
    ext_modules=ext_modules,
)

View File

@ -0,0 +1,197 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
/*
 * Copy a Python sequence of integers into a newly allocated C array.
 *
 * The buffer holds PySequence_Size(pylist) + 1 elements. It is allocated
 * with calloc, so the extra trailing element is zero.
 *
 * Returns NULL on failure with a Python exception set:
 *   - pylist is NULL (an exception is set if none is pending),
 *   - pylist is not a sequence (TypeError),
 *   - the length or an item cannot be retrieved,
 *   - an item cannot be converted with PyLong_AsLong,
 *   - the allocation fails (MemoryError).
 *
 * On success the caller owns the buffer and must release it with free().
 */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!pylist) {
    /* Usually the result of a failed lookup; keep the pending exception. */
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_TypeError, "expected a sequence, got NULL");
    return NULL;
  }
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence of integers");
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception already set */

  long* result = (long*)calloc((size_t)len + 1, sizeof(long));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    long value = PyLong_AsLong(item);
    Py_DECREF(item);
    /* -1 is a legal value, so the error indicator must be consulted too. */
    if (value == -1 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
    result[i] = value;
  }
  /* result[len] is already zero thanks to calloc. */
  return result;
}
/*
 * Copy a Python sequence of numbers into a newly allocated C array of double.
 *
 * The buffer holds PySequence_Size(pylist) + 1 elements. It is allocated
 * with calloc, so the extra trailing element is zero.
 *
 * Returns NULL on failure with a Python exception set:
 *   - pylist is NULL (an exception is set if none is pending),
 *   - pylist is not a sequence (TypeError),
 *   - the length or an item cannot be retrieved,
 *   - an item cannot be converted with PyFloat_AsDouble,
 *   - the allocation fails (MemoryError).
 *
 * On success the caller owns the buffer and must release it with free().
 */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!pylist) {
    /* Usually the result of a failed lookup; keep the pending exception. */
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_TypeError, "expected a sequence, got NULL");
    return NULL;
  }
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence of numbers");
    return NULL;
  }

  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception already set */

  double* result = (double*)calloc((size_t)len + 1, sizeof(double));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }

  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    double value = PyFloat_AsDouble(item);
    Py_DECREF(item);
    /* -1.0 is a legal value, so the error indicator must be consulted too. */
    if (value == -1.0 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
    result[i] = value;
  }
  /* result[len] is already zero thanks to calloc. */
  return result;
}
/*
 * Read attribute `attributeName` from `container` and convert it to a C
 * array of long with PySequenceToLongArray.
 *
 * Returns NULL if the attribute lookup fails (the exception raised by
 * PyObject_GetAttrString is left set) or if the conversion returns NULL.
 * On success the caller owns the returned buffer and must free() it.
 */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* Py_DECREF on NULL is invalid, so bail out before touching the object. */
  if (!sequence) return NULL;
  long* result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * Read attribute `attributeName` from `container` and convert it to a C
 * array of double with PySequenceToDoubleArray.
 *
 * Returns NULL if the attribute lookup fails (the exception raised by
 * PyObject_GetAttrString is left set) or if the conversion returns NULL.
 * On success the caller owns the returned buffer and must free() it.
 */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* Py_DECREF on NULL is invalid, so bail out before touching the object. */
  if (!sequence) return NULL;
  double* result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * Index helpers for the two flattened tables used by computeTable. Both
 * tables are laid out as [m][i][l] with `side` entries along the i and l
 * axes. The arithmetic is done in size_t so that it does not depend on the
 * width of long.
 */
#define COST_TABLE(m, i, l) \
  costTable[((size_t)(m) * side + (size_t)(i)) * side + (size_t)(l)]
#define BACK_PTR(m, i, l) \
  backPtr[((size_t)(m) * side + (size_t)(i)) * side + (size_t)(l)]

/*
 * compute_table(chain, mmax) -> (cost_table, back_ptr)
 *
 * Fills the dynamic-programming tables for `chain` for every memory budget
 * m in 0..mmax and every interval [i, l] with 0 <= i <= l <= len(chain).
 *
 * Arguments (parsed from `args`):
 *   chain - object exposing len() and the sequence attributes ftime, btime,
 *           x, xbar, ftmp and btmp. The attribute lengths are not validated
 *           here; the code indexes ftime/ftmp up to len(chain) and
 *           x/xbar/btime/btmp up to len(chain) + 1, relying on the zero
 *           padding element added by the array helpers.
 *   mmax  - largest memory budget. A negative value yields two empty lists.
 *
 * Returns a 2-tuple of lists indexed [m][i]; each innermost element is a
 * dict keyed by l:
 *   cost_table[m][i][l] - float cost, inf when the interval is infeasible.
 *   back_ptr[m][i][l]   - (True,) when the stored pointer is negative,
 *                         otherwise (False, j) with j the stored split index.
 *
 * Returns NULL with an exception set on failure: the chain is empty or has
 * negative x/xbar entries (ValueError), an attribute is missing or cannot be
 * converted, or memory runs out (MemoryError).
 */
static PyObject* computeTable(PyObject* self, PyObject* args) {
  PyObject* chainParam;
  int mmax;
  if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL;

  /* Everything released at `cleanup` starts out NULL so that any failure
   * path can jump there without leaking or freeing garbage. */
  double* ftime = NULL;
  double* btime = NULL;
  long* x = NULL;
  long* xbar = NULL;
  long* ftmp = NULL;
  long* btmp = NULL;
  double* costTable = NULL;
  long* backPtr = NULL;
  PyObject* pyCostTable = NULL;
  PyObject* pyBackPtr = NULL;
  PyObject* result = NULL;
  long chainLength = 0;
  size_t side = 0;
  size_t cells = 0;
  /* Number of budget rows; negative budgets produce empty tables. */
  const size_t rows = mmax < 0 ? 0 : (size_t)mmax + 1;

  ftime = getDoubleArray(chainParam, "ftime");
  if (!ftime) goto cleanup;
  btime = getDoubleArray(chainParam, "btime");
  if (!btime) goto cleanup;
  x = getLongArray(chainParam, "x");
  if (!x) goto cleanup;
  xbar = getLongArray(chainParam, "xbar");
  if (!xbar) goto cleanup;
  /* Forward temporaries come from "ftmp"; reading "btmp" twice made the
   * forward and backward limits identical. */
  ftmp = getLongArray(chainParam, "ftmp");
  if (!ftmp) goto cleanup;
  btmp = getLongArray(chainParam, "btmp");
  if (!btmp) goto cleanup;

  Py_ssize_t length = PyObject_Length(chainParam);
  if (length < 0) goto cleanup; /* exception already set */
  if (length == 0) {
    PyErr_SetString(PyExc_ValueError, "chain must not be empty");
    goto cleanup;
  }
  chainLength = (long)length;
  side = (size_t)chainLength + 1;

  /* x[j] and xbar[j] are subtracted from the budget to form a table index,
   * so a negative entry would address memory past the end of the table. */
  for (long k = 1; k <= chainLength; ++k) {
    if (x[k] < 0 || xbar[k] < 0) {
      PyErr_SetString(PyExc_ValueError,
                      "chain.x and chain.xbar must be non-negative");
      goto cleanup;
    }
  }

  /* Guard rows * side * side against overflow before allocating. */
  if (rows > 0) {
    if (side > (size_t)PY_SSIZE_T_MAX / side ||
        side * side > (size_t)PY_SSIZE_T_MAX / rows) {
      PyErr_NoMemory();
      goto cleanup;
    }
    cells = rows * side * side;
  }
  /* calloc(0, ...) may legitimately return NULL; always request >= 1. */
  costTable = (double*)calloc(cells ? cells : 1, sizeof(double));
  backPtr = (long*)calloc(cells ? cells : 1, sizeof(long));
  if (!costTable || !backPtr) {
    PyErr_NoMemory();
    goto cleanup;
  }

  /* Intervals of length zero: feasible only when the budget covers both the
   * forward and the backward temporary on top of x and xbar. */
  for (long m = 0; m <= mmax; ++m) {
    for (long i = 0; i <= chainLength; ++i) {
      if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
          (m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
        COST_TABLE(m, i, i) = ftime[i] + btime[i];
      else
        COST_TABLE(m, i, i) = INFINITY;
    }
  }

  /* Longer intervals, by increasing length d, so that every sub-interval
   * referenced below has already been filled in. */
  for (long m = 0; m <= mmax; ++m) {
    for (long d = 1; d <= chainLength; ++d) {
      for (long i = 0; i <= chainLength - d; ++i) {
        long idx = i + d;

        /* Minimum budget needed to run the forward steps i..idx-1. */
        long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            long stepCost = x[j] + x[j + 1] + ftmp[j];
            if (stepCost > maxCostFWD) maxCostFWD = stepCost;
          }
          if (x[idx + 1] + maxCostFWD > mmin) mmin = x[idx + 1] + maxCostFWD;
        }

        if (m < mmin) {
          COST_TABLE(m, i, idx) = INFINITY;
          continue;
        }

        /* Option 1: split at j, keeping x[j] in memory. */
        long bestLeaf = -1;
        double sumFw = 0;
        double bestLeafCost = INFINITY;
        for (long j = i + 1; j <= idx; ++j) {
          sumFw += ftime[j - 1];
          if (m >= x[j]) {
            double cost = sumFw + COST_TABLE(m - x[j], j, idx) +
                          COST_TABLE(m, i, j - 1);
            if (cost < bestLeafCost) {
              bestLeafCost = cost;
              bestLeaf = j;
            }
          }
        }

        /* Option 2: keep xbar[i + 1] and recurse on the rest. */
        double chainCost = INFINITY;
        if (m >= xbar[i + 1])
          chainCost =
              COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);

        if (bestLeafCost <= chainCost) {
          COST_TABLE(m, i, idx) = bestLeafCost;
          BACK_PTR(m, i, idx) = bestLeaf;
        } else {
          COST_TABLE(m, i, idx) = chainCost;
          BACK_PTR(m, i, idx) = -1; /* negative marks option 2 */
        }
      }
    }
  }

  /* Convert the result into Python world. */
  pyCostTable = PyList_New((Py_ssize_t)rows);
  pyBackPtr = PyList_New((Py_ssize_t)rows);
  if (!pyCostTable || !pyBackPtr) goto cleanup;

  for (long m = 0; m <= mmax; ++m) {
    /* PyList_SET_ITEM steals the reference, so each container is owned by
     * its parent as soon as it is stored. */
    PyObject* pyCostTable_m = PyList_New((Py_ssize_t)side);
    if (!pyCostTable_m) goto cleanup;
    PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);
    PyObject* pyBackPtr_m = PyList_New((Py_ssize_t)side);
    if (!pyBackPtr_m) goto cleanup;
    PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);

    for (long i = 0; i <= chainLength; ++i) {
      PyObject* pyCostTable_m_i = PyDict_New();
      if (!pyCostTable_m_i) goto cleanup;
      PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);
      PyObject* pyBackPtr_m_i = PyDict_New();
      if (!pyBackPtr_m_i) goto cleanup;
      PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);

      for (long l = i; l <= chainLength; ++l) {
        PyObject* pyVar_l = PyLong_FromLong(l);
        PyObject* pyCost = PyFloat_FromDouble(COST_TABLE(m, i, l));
        PyObject* pyPtr =
            (BACK_PTR(m, i, l) < 0)
                ? Py_BuildValue("(O)", Py_True)
                : Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));

        int failed = !pyVar_l || !pyCost || !pyPtr ||
                     PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCost) < 0 ||
                     PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyPtr) < 0;

        /* PyDict_SetItem does not steal references; drop ours. */
        Py_XDECREF(pyVar_l);
        Py_XDECREF(pyCost);
        Py_XDECREF(pyPtr);
        if (failed) goto cleanup;
      }
    }
  }

  result = PyTuple_Pack(2, pyCostTable, pyBackPtr);

cleanup:
  free(ftime);
  free(btime);
  free(x);
  free(xbar);
  free(ftmp);
  free(btmp);
  free(costTable);
  free(backPtr);
  Py_XDECREF(pyCostTable);
  Py_XDECREF(pyBackPtr);
  /* Never hand NULL back to the interpreter without an exception. */
  if (!result && !PyErr_Occurred())
    PyErr_SetString(PyExc_TypeError,
                    "chain attributes must be sequences of numbers");
  return result;
}
static PyMethodDef rotorMethods[] = {
{"compute_table", computeTable, METH_VARARGS,
"Compute the optimal table with the rotor algorithm."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
static struct PyModuleDef rotorModule = {
PyModuleDef_HEAD_INIT, "rotorc", /* name of module */
"A simple implementation of dynamic programming algorithm rotor with C in "
"https://hal.inria.fr/hal-02352969. Some code are adapted from "
"https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be
NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
rotorMethods};
PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); }

View File

@ -1,5 +1,5 @@
from copy import deepcopy from copy import deepcopy
from typing import Dict, List, Tuple from typing import Any, Dict, List, Tuple
from torch import Tensor from torch import Tensor
from torch.fx import Graph, Node from torch.fx import Graph, Node
@ -15,9 +15,9 @@ from colossalai.fx.profiler import (
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverBase'] __all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase): class CheckpointSolverRotor(CheckpointSolverBase):
@ -59,11 +59,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
self.back_ptr = None self.back_ptr = None
self.sequence = None self.sequence = None
def solve(self, force_python: bool = False) -> Graph: def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
"""Solve the checkpointing problem using rotor algorithm. """Solve the checkpointing problem using rotor algorithm.
Args: Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False. force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
verbose (bool, optional): Print verbose information. Defaults to False.
Returns: Returns:
graph (Graph): The optimized graph, should be a copy of the original graph. graph (Graph): The optimized graph, should be a copy of the original graph.
@ -76,14 +77,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
else: else:
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots) self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
if verbose:
self.print_chain()
# backtrack # backtrack
try: try:
self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr) self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
self.back_ptr)
self._annotate_from_sequence(self.sequence, self.node_list) self._annotate_from_sequence(self.sequence, self.node_list)
except RuntimeError as e: except ValueError as e:
# using logger to annonce that the solver is failed # using logger to annonce that the solver is failed
logger = get_dist_logger() logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}') logger.warning(f'Checkpoint solver failed: {e}')
raise ValueError
if verbose:
self.print_sequence()
return deepcopy(self.graph) return deepcopy(self.graph)
@ -100,42 +109,42 @@ class CheckpointSolverRotor(CheckpointSolverBase):
@classmethod @classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
input_tensors = cls._extract_input(graph) input_tensors = cls._extract_input(graph)
fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list() ftime, btime, ftmp, btmp = list(), list(), list(), list()
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)] xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
for idx, node in enumerate(node_list): for node in node_list:
node_info = cls._extract_node_info(node) node_info = cls._extract_node_info(node)
fwd_time.append(node_info[0]) ftime.append(node_info[0])
bwd_time.append(node_info[1]) btime.append(node_info[1])
x.append(node_info[2]) x.append(node_info[2])
xbar.append(node_info[3]) xbar.append(node_info[3])
ftmp.append(node_info[4]) ftmp.append(node_info[4])
btmp.append(node_info[5]) btmp.append(node_info[5])
# currently we view loss backward temp as zero # currently we view loss backward temp as zero
bwd_time.append(0) btime.append(0)
btmp.append(0) btmp.append(0)
return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp) return Chain(ftime, btime, x, xbar, ftmp, btmp)
@classmethod @classmethod
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
"""Extract node info from a list of nodes""" """Extract node info from a list of nodes"""
xbar = 0 xbar = 0
fwd_time = 0 ftime = 0
bwd_time = 0 btime = 0
for n in node: for n in node:
assert isinstance(n, Node), f'{n} is not a Node' assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
# minimum flop count is required # minimum flop count is required
fwd_time += max(calculate_fwd_time(n), 1.0) ftime += max(calculate_fwd_time(n), 1.0)
bwd_time += max(calculate_bwd_time(n), 1.0) btime += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1]) x = calculate_fwd_out(node[-1])
xbar = max(x, xbar) xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node) ftmp = cls._extract_ftmp(node)
btmp = cls._extract_btmp(node) btmp = cls._extract_btmp(node)
return fwd_time, bwd_time, x, xbar, ftmp, btmp return ftime, btime, x, xbar, ftmp, btmp
@staticmethod @staticmethod
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
@ -180,17 +189,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return btmp return btmp
@staticmethod @staticmethod
def _compute_table(chain: Chain, mem_slots: int) -> Tuple: def _compute_table(chain: Chain, mmax: int) -> Tuple:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer. """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args: Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem. chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mem_slots (int): Number of slots for discretizing memory budget. mmax (int): Maximum number of memory slots.
Returns: Returns:
cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
of length j of length j
""" """
@ -203,13 +212,13 @@ class CheckpointSolverRotor(CheckpointSolverBase):
btmp = chain.btmp + [0] btmp = chain.btmp + [0]
# Build table # Build table
cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)] cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)] back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0 # Initialize borders of the tables for lmax-lmin = 0
for m in range(mem_slots + 1): for m in range(mmax + 1):
for i in range(chain.length + 1): for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1) if m >= limit: # Equation (1)
cost_table[m][i][i] = ftime[i] + btime[i] cost_table[m][i][i] = ftime[i] + btime[i]
@ -217,9 +226,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
cost_table[m][i][i] = float("inf") cost_table[m][i][i] = float("inf")
# Compute everything # Compute everything
for m in range(mem_slots + 1): for m in range(mmax + 1):
for d in range(1, chain.length + 1): for d in range(1, len(chain) + 1):
for i in range(chain.length + 1 - d): for i in range(len(chain) + 1 - d):
idx = i + d idx = i + d
mmin = x[idx + 1] + x[i + 1] + ftmp[i] mmin = x[idx + 1] + x[i + 1] + ftmp[i]
if idx > i + 1: if idx > i + 1:
@ -248,20 +257,46 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return cost_table, back_ptr return cost_table, back_ptr
@staticmethod @staticmethod
def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple: def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
raise NotImplementedError("C implementation not available yet") try:
from .rotorc import compute_table
def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]], # build module if module not found
back_ptr: List[List[Dict[int, int]]]) -> List[int]: except ModuleNotFoundError:
import os
import subprocess
import sys
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.wait() == 0:
logger.info("rotorc has been built!", ranks=[0])
from .rotorc import compute_table
else:
logger.warning("rotorc built failed! Using python version!", ranks=[0])
return CheckpointSolverRotor._compute_table(chain, mmax)
return compute_table(chain, mmax)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
back_ptr: List[Any]) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy. """Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args: Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem. chain (Chain): A basic linearized structure for solving the dynamic programming problem.
lmin (int): The left index of the interval to backtrack. lhs (int): The left index of the interval to backtrack.
lmax (int): The right index of the interval to backtrack. rhs (int): The right index of the interval to backtrack.
mem_budget (int): The memory budget for processing this interval. budget (int): The memory budget for processing this interval.
cost_table (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions cost_table (List[Any]): See `._compute_table()` for definitions
back_ptr (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions back_ptr (List[Any]): See `._compute_table()` for definitions
Raises: Raises:
ValueError: Can not process the chain. ValueError: Can not process the chain.
@ -269,36 +304,45 @@ class CheckpointSolverRotor(CheckpointSolverBase):
Returns: Returns:
sequence (Sequence): The sequence of executing nodes with checkpoints. sequence (Sequence): The sequence of executing nodes with checkpoints.
""" """
if mem_budget <= 0: if budget <= 0:
raise ValueError(f"Can not process a chain with negative memory {mem_budget}") raise ValueError(f"Can not process a chain with negative memory {budget}")
elif cost_table[mem_budget][lmin][lmax] == float("inf"): elif cost_table[budget][lhs][rhs] == float("inf"):
raise ValueError(f"Can not process this chain from index {lmin} to {lmax} with memory {mem_budget}") raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")
sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget)) sequence = Sequence()
if lmin == lmax: if rhs == lhs:
if lmin == chain.length: if lhs == len(chain):
sequence.insert(Loss()) sequence += [Loss()]
else: else:
sequence.insert(ForwardEnable(lmin)) sequence += [ForwardEnable(lhs), Backward(lhs)]
sequence.insert(Backward(lmin))
return sequence return sequence
if back_ptr[mem_budget][lmin][lmax][0]: if back_ptr[budget][lhs][rhs][0]:
sequence.insert(ForwardEnable(lmin)) sequence += [
sequence.insert_sequence( ForwardEnable(lhs),
self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr)) CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
sequence.insert(Backward(lmin)) back_ptr),
Backward(lhs),
]
else: else:
j = back_ptr[mem_budget][lmin][lmax][1] best_leaf = back_ptr[budget][lhs][rhs][1]
sequence.insert(ForwardCheck(lmin)) sequence += [ForwardCheck(lhs)]
for k in range(lmin + 1, j): sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence.insert(ForwardNograd(k)) sequence += [
sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr)) CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr)) back_ptr),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence return sequence
@staticmethod @staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
node_list (List[List[Node]]): The list of nodes to annotate.
"""
op_list = sequence.list_operations() op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss)) loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)] fwd_list = op_list[:op_list.index(loss_op)]

View File

@ -1,6 +1,6 @@
import math import math
from abc import ABC from abc import ABC
from typing import List from typing import Any, Iterable, List
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
@ -33,23 +33,25 @@ class Chain:
self.xbar = xbar self.xbar = xbar
self.ftmp = ftmp self.ftmp = ftmp
self.btmp = btmp self.btmp = btmp
self.length = len(ftime)
if check_consistency and not self.check_lengths(): if check_consistency and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths") raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self): def check_lengths(self):
return ((len(self.ftime) == self.length) and (len(self.btime) == self.length + 1) return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
and (len(self.x) == self.length + 1) and (len(self.ftmp) == self.length) and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
and (len(self.btmp) == self.length + 1) and (len(self.xbar) == self.length + 1)) and (len(self.xbar) == len(self) + 1))
def __repr__(self): def __repr__(self):
chain_list = [] chain_list = []
for i in range(self.length): for i in range(len(self)):
chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])) chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
i = self.length i = len(self)
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i])) chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
return chain_list.__repr__() return chain_list.__repr__()
def __len__(self):
return len(self.ftime)
def discretize_all(self, unit: int): def discretize_all(self, unit: int):
"""Discretize the chain into a list of chains according to unit size.""" """Discretize the chain into a list of chains according to unit size."""
discretizer = lambda val: math.ceil(val / unit) discretizer = lambda val: math.ceil(val / unit)
@ -163,79 +165,20 @@ class DiscardMemory(MemoryAccess):
name = "DM" name = "DM"
class Function: class Sequence(list):
def __init__(self, name, *args): def __init__(self):
self.name = name super().__init__()
self.args = args
self.str_args = ','.join(str(v) for v in self.args)
def __repr__(self):
return "{n}({args})".format(n=self.name, args=self.str_args)
class Sequence:
def __init__(self, function):
self.sequence = [] #List of Operation and Sequence
self.function = function #Description the function (name and parameters)
def __repr__(self): def __repr__(self):
return repr(self.list_operations()) return repr(self.list_operations())
def list_operations(self): def list_operations(self):
op_list = [] op_list = []
for x in self.sequence: for x in self:
if isinstance(x, Operation): if isinstance(x, Operation):
op_list.append(x) op_list.append(x)
else: else:
assert isinstance(x, Sequence) assert isinstance(x, Sequence)
op_list += x.list_operations() op_list += x.list_operations()
return op_list return op_list
def insert(self, operation):
self.sequence.append(operation)
def remove(self, operation_index):
del self.sequence[operation_index]
def insert_sequence(self, sequence):
self.sequence.append(sequence)
def shift(self, value):
for x in self.sequence:
x.shift(value)
return self
def remove_useless_write(self):
if self.sequence:
if isinstance(self.sequence[0], WriteMemory):
self.remove(0)
return self
def get_makespan(self, chain):
return sum(op.cost(chain) for op in self.list_operations())
def without_suffix(self):
ops = self.list_operations()
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
try:
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
except ValueError:
last_idx = -1
if last_idx == end_of_first_phase - 1:
return (self, None)
chain_length = ops[end_of_first_phase -
1].index ## Some assumption here about the sequence (finishes with Forward_L
start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
for i in range(last_idx + 1):
result.insert(ops[i])
result.insert(Loss())
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
position = end_of_first_phase + 1 + (chain_length - i)
assert type(ops[position]) is Backward
assert ops[position].index == i
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
result.insert(ops[i])
return (result, start_of_fwd_enable_chain)

View File

@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
out, meta = _profile_concrete(func, *args, **kwargs) out, meta = _profile_concrete(func, *args, **kwargs)
if inplace: if inplace:
kwargs['inplace'] = True kwargs['inplace'] = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False do_not_cache = False
meta.bwd_mem_out -= param_size meta.bwd_mem_out -= param_size
@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
out, meta = _profile_concrete(func, *args, **kwargs) out, meta = _profile_concrete(func, *args, **kwargs)
if inplace: if inplace:
module.inplace = True module.inplace = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False do_not_cache = False
# grad for param will not be counted # grad for param will not be counted