|
|
|
import math
|
|
|
|
from abc import ABC
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
|
|
|
|
|
|
class Chain:
    """Linearized chain description used by the activation-checkpoint solver.

    A chain of length ``n`` (``n == len(ftime)``) is consistent when ``ftime``
    and ``ftmp`` have ``n`` entries and ``btime``, ``x``, ``xbar`` and ``btmp``
    have ``n + 1`` entries (see :meth:`check_lengths`).
    """

    def __init__(
        self,
        ftime: List[float],
        btime: List[float],
        x: List[int],
        xbar: List[int],
        ftmp: List[int],
        btmp: List[int],
        check_consistency: bool = True,
    ):
        """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
        See paper https://hal.inria.fr/hal-02352969 for details.

        Args:
            ftime (List[float]): The forward time of each node.
            btime (List[float]): The backward time of each node.
            x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
            xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
            ftmp (List[int]): The temporary forward memory of each node.
            btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
            check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True.

        Raises:
            AttributeError: If ``check_consistency`` is true and the list lengths
                do not satisfy :meth:`check_lengths`.
        """
        self.ftime = ftime
        self.btime = btime
        self.x = x
        self.xbar = xbar
        self.ftmp = ftmp
        self.btmp = btmp
        if check_consistency and not self.check_lengths():
            # AttributeError (rather than ValueError) is kept for compatibility
            # with existing callers that may catch it.
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self) -> bool:
        """Return True if the per-node lists have the lengths expected for ``len(self)`` stages."""
        n = len(self)
        return (
            len(self.ftime) == n
            and len(self.btime) == n + 1
            and len(self.x) == n + 1
            and len(self.ftmp) == n
            and len(self.btmp) == n + 1
            and len(self.xbar) == n + 1
        )

    def __repr__(self):
        chain_list = []
        for i in range(len(self)):
            chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
        # The trailing entry has no forward time / forward temp memory.
        i = len(self)
        chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
        return chain_list.__repr__()

    def __len__(self):
        """Number of forward stages in the chain."""
        return len(self.ftime)

    def discretize_all(self, unit: int):
        """Discretize the memory lists of this chain in place, in slots of size ``unit``.

        Every entry of ``x``, ``xbar``, ``ftmp`` and ``btmp`` is replaced by the
        number of ``unit``-sized slots needed to hold it (ceiling division).
        ``ftime`` and ``btime`` are left unchanged.

        Args:
            unit (int): Size of one memory slot, in the same units as the memory lists.
        """

        def discretizer(val):
            # Exact integer ceiling division: ``val / unit`` goes through a float
            # and can round a just-over-a-multiple value down for large ints.
            if isinstance(val, int) and isinstance(unit, int):
                return -(-val // unit)
            return math.ceil(val / unit)

        self.x = tree_map(discretizer, self.x)
        self.xbar = tree_map(discretizer, self.xbar)
        self.ftmp = tree_map(discretizer, self.ftmp)
        self.btmp = tree_map(discretizer, self.btmp)
|
|
|
|
|
|
|
|
|
|
|
|
class Operation(ABC):
    """Base class for a single step of a checkpointing schedule.

    Subclasses that refer to a stage store it in ``self.index`` (an int, or a
    tuple of ints) and override the class-level ``name`` tag used in ``repr``.
    """

    name = "Op"

    def __repr__(self) -> str:
        return "{}_{}".format(self.name, self.index)

    def shift(self, value):
        """Offset the index in place by ``value``; for a tuple index every element is offset."""
        if type(self.index) is not tuple:
            self.index += value
        else:
            self.index = tuple(component + value for component in self.index)
|
|
|
|
|
|
|
|
|
|
|
|
class Forward(Operation):
    """Forward computation of the single stage ``index``."""

    name = "F"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Return ``chain.ftime[index]``, or a unit cost of 1 when ``chain`` is None."""
        return 1 if chain is None else chain.ftime[self.index]
|
|
|
|
|
|
|
|
|
|
|
|
class ForwardEnable(Forward):
    """Forward step tagged ``Fe``; identical to ``Forward`` except for its name.

    NOTE(review): presumably a forward run with autograd enabled — confirm
    against the solver code that emits it.
    """

    name = "Fe"
|
|
|
|
|
|
|
|
|
|
|
|
class ForwardNograd(Forward):
    """Forward step tagged ``Fn``; identical to ``Forward`` except for its name.

    NOTE(review): presumably a forward run without gradient tracking — confirm
    against the solver code that emits it.
    """

    name = "Fn"
|
|
|
|
|
|
|
|
|
|
|
|
class ForwardCheck(Forward):
    """Forward step tagged ``CF``; identical to ``Forward`` except for its name.

    NOTE(review): presumably a forward that checkpoints its input — confirm
    against the solver code that emits it.
    """

    name = "CF"
|
|
|
|
|
|
|
|
|
|
|
|
class Forwards(Operation):
    """A run of forward steps over stages ``start`` through ``end`` (both inclusive)."""

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        start, end = self.index[0], self.index[1]
        return "F_{i}->{j}".format(i=start, j=end)

    def cost(self, chain: Chain):
        """Sum of ``chain.ftime`` over the covered stages, or the stage count when ``chain`` is None."""
        start, end = self.index[0], self.index[1]
        if chain is None:
            return end - start + 1
        return sum(chain.ftime[start : end + 1])
|
|
|
|
|
|
|
|
|
|
|
|
def isForward(op):
    """Return True when the exact type of ``op`` is ``Forward`` or ``Forwards``.

    The comparison is on the exact type, not ``isinstance``, so subclasses such
    as ``ForwardEnable``/``ForwardNograd``/``ForwardCheck`` are not matched.
    """
    return type(op) in (Forward, Forwards)
|
|
|
|
|
|
|
|
|
|
|
|
class Backward(Operation):
    """Backward computation of the single stage ``index``."""

    name = "B"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Return ``chain.btime[index]``, or a unit cost of 1 when ``chain`` is None."""
        return 1 if chain is None else chain.btime[self.index]
|
|
|
|
|
|
|
|
|
|
|
|
class Loss(Operation):
    """Marker operation for the loss step; it has no stage index and costs nothing.

    Because no ``index`` is stored, the inherited ``shift`` would raise
    AttributeError; ``__repr__`` is overridden so it does not need one.
    """

    def __init__(self):
        """Create the loss marker (stores no state)."""

    def __repr__(self):
        return "L"

    def cost(self, chain):
        return 0
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryAccess(Operation):
    """Base class for memory-related steps on stage ``index``; they contribute zero cost."""

    name = "MA"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        # Memory steps are never charged time, whatever chain is given.
        return 0
|
|
|
|
|
|
|
|
|
|
|
|
class WriteMemory(MemoryAccess):
    """Memory step tagged ``WM`` (presumably storing a stage's data — confirm with the solver)."""

    name = "WM"
|
|
|
|
|
|
|
|
|
|
|
|
class ReadMemory(MemoryAccess):
    """Memory step tagged ``RM`` (presumably reloading a stage's stored data — confirm with the solver)."""

    name = "RM"
|
|
|
|
|
|
|
|
|
|
|
|
class DiscardMemory(MemoryAccess):
    """Memory step tagged ``DM`` (presumably freeing a stage's stored data — confirm with the solver)."""

    name = "DM"
|
|
|
|
|
|
|
|
|
|
|
|
class Sequence(list):
    """A list of operations that may contain nested ``Sequence`` objects.

    ``repr`` shows the flattened operation list, not the nested structure.
    """

    def __init__(self):
        # Accepts no arguments: a Sequence always starts empty and is filled
        # with ordinary list methods such as ``append``.
        super().__init__()

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        """Return a new flat list of all operations in order, recursing into nested Sequences."""
        flat = []
        for item in self:
            if not isinstance(item, Operation):
                # Anything that is not an Operation must be a nested Sequence.
                assert isinstance(item, Sequence)
                flat.extend(item.list_operations())
                continue
            flat.append(item)
        return flat
|