ColossalAI/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py

427 lines
17 KiB
Python
Raw Normal View History

from copy import deepcopy
from typing import Any, Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
calculate_bwd_time,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`:
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
free_memory (float, optional): Memory constraint for the solution, unit is byte.
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
"""
super().__init__(graph, free_memory, True, cnode)
self.memory_slots = memory_slots
# construct chain
unit = self.free_memory // self.memory_slots
self.chain = self._construct_chain(self.graph, self.node_list)
self.chain.discretize_all(unit)
self.cost_table = None
self.back_ptr = None
self.sequence = None
def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
"""Solve the checkpointing problem using rotor algorithm.
Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
verbose (bool, optional): Print verbose information. Defaults to False.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
chain = self.chain
# compute cost table
if force_python:
self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
else:
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
if verbose:
self.print_chain()
# backtrack
try:
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
self.back_ptr)
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
# using logger to annonce that the solver is failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
raise ValueError
if verbose:
self.print_sequence()
return deepcopy(self.graph)
def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
self.chain.btmp[idx])
print(f'Chain = {self.chain}')
def print_sequence(self):
print(f'Sequence = {self.sequence}')
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
input_tensors = cls._extract_input(graph)
ftime, btime, ftmp, btmp = list(), list(), list(), list()
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
for node in node_list:
node_info = cls._extract_node_info(node)
ftime.append(node_info[0])
btime.append(node_info[1])
x.append(node_info[2])
xbar.append(node_info[3])
ftmp.append(node_info[4])
btmp.append(node_info[5])
# currently we view loss backward temp as zero
btime.append(0)
btmp.append(0)
return Chain(ftime, btime, x, xbar, ftmp, btmp)
@classmethod
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
"""Extract node info from a list of nodes"""
xbar = 0
ftime = 0
btime = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
btmp = cls._extract_btmp(node)
return ftime, btime, x, xbar, ftmp, btmp
@staticmethod
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
if node.op == 'placeholder':
input_tensors.append(node.meta['fwd_out'])
return input_tensors
@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
"""Extract ftmp from a list of nodes"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
"""Extract btmp from a list of nodes"""
def _extract_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
btmp = 0
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
return btmp
@staticmethod
def _compute_table(chain: Chain, mmax: int) -> Tuple:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mmax (int): Maximum number of memory slots.
Returns:
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
of length j
"""
ftime = chain.ftime + [0.0]
btime = chain.btime
x = chain.x + [0]
xbar = chain.xbar + [0]
ftmp = chain.ftmp + [0]
btmp = chain.btmp + [0]
# Build table
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mmax + 1):
for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1)
cost_table[m][i][i] = ftime[i] + btime[i]
else:
cost_table[m][i][i] = float("inf")
# Compute everything
for m in range(mmax + 1):
for d in range(1, len(chain) + 1):
for i in range(len(chain) + 1 - d):
idx = i + d
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
if idx > i + 1:
mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j,
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
best_leaf = None
if m >= xbar[i + 1]:
chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
else:
chain_checkpoint = float("inf")
if best_leaf and best_leaf[1] <= chain_checkpoint:
cost_table[m][i][idx] = best_leaf[1]
back_ptr[m][i][idx] = (False, best_leaf[0])
else:
cost_table[m][i][idx] = chain_checkpoint
back_ptr[m][i][idx] = (True,)
return cost_table, back_ptr
@staticmethod
def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
try:
from .rotorc import compute_table
# build module if module not found
except ModuleNotFoundError:
import os
import subprocess
import sys
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.wait() == 0:
logger.info("rotorc has been built!", ranks=[0])
from .rotorc import compute_table
else:
logger.warning("rotorc built failed! Using python version!", ranks=[0])
return CheckpointSolverRotor._compute_table(chain, mmax)
return compute_table(chain, mmax)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
back_ptr: List[Any]) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
lhs (int): The left index of the interval to backtrack.
rhs (int): The right index of the interval to backtrack.
budget (int): The memory budget for processing this interval.
cost_table (List[Any]): See `._compute_table()` for definitions
back_ptr (List[Any]): See `._compute_table()` for definitions
Raises:
ValueError: Can not process the chain.
Returns:
sequence (Sequence): The sequence of executing nodes with checkpoints.
"""
if budget <= 0:
raise ValueError(f"Can not process a chain with negative memory {budget}")
elif cost_table[budget][lhs][rhs] == float("inf"):
raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")
sequence = Sequence()
if rhs == lhs:
if lhs == len(chain):
sequence += [Loss()]
else:
sequence += [ForwardEnable(lhs), Backward(lhs)]
return sequence
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
back_ptr),
Backward(lhs),
]
else:
best_leaf = back_ptr[budget][lhs][rhs][1]
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
node_list (List[List[Node]]): The list of nodes to annotate.
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for idx, op in enumerate(fwd_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(idx)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))