from copy import deepcopy
from typing import Any, List, Optional, Tuple

from torch import Tensor
from torch.fx import Graph, Node

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
    activation_size,
    calculate_bwd_time,
    calculate_fwd_out,
    calculate_fwd_time,
    calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger

from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence

__all__ = ['CheckpointSolverRotor']
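
# A minimal end-to-end sketch of how this solver is typically driven. It
# assumes the graph has already been traced and profiled so that every node
# carries the ``meta`` statistics consumed below; the profiling pass and its
# import path vary across ColossalAI versions, so treat the middle step as a
# placeholder rather than a fixed API:
#
#   gm = torch.fx.GraphModule(model, traced_graph)
#   ...populate node.meta with ColossalAI's profiling / metainfo pass...
#   solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
#   gm.graph = solver.solve()
#   gm.recompile()    # make the annotated graph executable again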


class CheckpointSolverRotor(CheckpointSolverBase):

    def __init__(self,
                 graph: Graph,
                 free_memory: float = -1,
                 cnode: Optional[List[str]] = None,
                 memory_slots: int = 500,
                 optim_multiplier: float = 1.0):
        """This is a simple implementation of the rotor dynamic programming
        algorithm from https://hal.inria.fr/hal-02352969. Some code is adapted
        from https://gitlab.inria.fr/hiepacs/rotor.

        Usage:
            Assume that we have a ``GraphModule``, and we have already done the extractions
            on the graph to retrieve all the information needed. Then we could use the following
            code to find a solution with ``CheckpointSolverRotor``:
            >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
            >>> rotor_graph = solver.solve(force_python=True)    # otherwise use the C solver
            >>> gm.graph = rotor_graph    # set the graph to the new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            free_memory (float, optional): Memory constraint for the solution, in bytes.
                Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free memory. Defaults to -1.
            cnode (Optional[List[str]], optional): Common node list; should be a subset of the input nodes. Defaults to None.
            memory_slots (int, optional): Number of slots for discretizing the memory budget. Defaults to 500.
            optim_multiplier (float, optional): The multiplier of extra weight storage for the
                ``torch.optim.Optimizer``. Defaults to 1.0.
        """
        super().__init__(graph, free_memory, True, cnode, optim_multiplier)
        self.memory_slots = memory_slots

        # construct chain
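        # The memory budget is discretized into ``memory_slots`` buckets so the
        # dynamic programming table stays a manageable size. For example, with
        # 10 GiB free and the default 500 slots, one slot represents ~20 MiB.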
        unit = self.free_memory // self.memory_slots
        self.chain = self._construct_chain(self.graph, self.node_list)
        self.chain.discretize_all(unit)

        self.cost_table = None
        self.back_ptr = None
        self.sequence = None

    def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
        """Solve the checkpointing problem using the rotor algorithm.

        Args:
            force_python (bool, optional): If True, use the Python version of the solver;
                otherwise use the C version. Defaults to False.
            verbose (bool, optional): Print verbose information. Defaults to False.

        Returns:
            graph (Graph): The optimized graph, a deep copy of the original graph.
        """
        chain = self.chain

        # compute cost table
        if force_python:
            self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
        else:
            self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)

        if verbose:
            self.print_chain()

        # backtrack
        try:
            self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
                                            self.back_ptr)
            self._annotate_from_sequence(self.sequence, self.node_list)
        except ValueError as e:
            # use the logger to announce that the solver failed
            logger = get_dist_logger()
            logger.warning(f'Checkpoint solver failed: {e}')
            raise

        if verbose:
            self.print_sequence()

        return deepcopy(self.graph)

    def print_chain(self):
        print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
        for idx in range(len(self.node_list) - 1):
            print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
                  self.chain.btmp[idx])
        print(f'Chain = {self.chain}')

    def print_sequence(self):
        print(f'Sequence = {self.sequence}')

    @classmethod
    def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
        input_tensors = cls._extract_input(graph)
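        # Per-stage quantities of the rotor chain (one entry per node bundle):
        # ftime/btime are forward/backward compute times, x is the size of the
        # stage output kept for backward, xbar is the total activation size of
        # the stage, and ftmp/btmp are peak temporary memory of the forward and
        # backward passes. x[0] and xbar[0] start from the graph input size.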
        ftime, btime, ftmp, btmp = list(), list(), list(), list()
        xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]

        for node in node_list:
            node_info = cls._extract_node_info(node)
            ftime.append(node_info[0])
            btime.append(node_info[1])
            x.append(node_info[2])
            xbar.append(node_info[3])
            ftmp.append(node_info[4])
            btmp.append(node_info[5])

        # currently we treat the temporary memory of the loss backward as zero
        btime.append(0)
        btmp.append(0)

        return Chain(ftime, btime, x, xbar, ftmp, btmp)

    @classmethod
    def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
        """Extract node info from a list of nodes"""
        xbar = 0
        ftime = 0
        btime = 0
        fwd_mem_peak = 0
        for n in node:
            assert isinstance(n, Node), f'{n} is not a Node'
            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
                # in this case we need to calculate memory usage directly from the statistics hooked in node.meta
                xbar += n.meta['fwd_mem_out']
                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
            else:
                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))

            # a minimum cost is enforced so that no node is treated as free
            ftime += max(calculate_fwd_time(n), 1.0)
            btime += max(calculate_bwd_time(n), 1.0)

        x = calculate_fwd_out(node[-1])
        xbar = max(x, xbar)
        ftmp = fwd_mem_peak - xbar
        btmp = cls._extract_btmp(node)
        return ftime, btime, x, xbar, ftmp, btmp

    @staticmethod
    def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
        """Extract input tensors from a Graph"""
        input_tensors = []
        for node in graph.nodes:
            if node.op == 'placeholder':
                input_tensors.append(node.meta['fwd_out'])
        return input_tensors

    @staticmethod
    def _extract_unused_output(node: Node) -> int:
        """Extract the size of the unused output from a `torch.fx.Node`"""
        return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)

    @staticmethod
    def _extract_btmp(node: List[Node]) -> int:
        """Extract the backward temporary memory (btmp) from a list of nodes"""

        def _extract_deps_size():
            deps_size = 0
            for k, v in deps.items():
                k: Node
                if v > 0:
                    deps_size += k.meta['bwd_mem_out']
                if v == float('-inf'):
                    deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)

            return deps_size

        btmp = 0
        deps = {}
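        # ``deps[k]`` counts how many inputs of node ``k`` are still waiting to
        # be processed in this reversed (backward-order) sweep. While the count
        # is positive, ``k``'s gradient output (bwd_mem_out) is live; once it
        # drops to zero the entry is set to -inf, marking that ``k``'s saved
        # forward buffers can be freed.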
        for n in reversed(node):
            deps[n] = len(n.all_input_nodes)
            btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
            for child in n.users:
                if child in deps:
                    deps[child] -= 1
                    if deps[child] <= 0:
                        deps[child] = float('-inf')    # free
        return btmp

    @staticmethod
    def _compute_table(chain: Chain, mmax: int) -> Tuple:
        """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            mmax (int): Maximum number of memory slots.

        Returns:
            cost_table (List): cost_table[m][lhs][rhs] is the optimal cost of the subproblem from lhs to rhs
                with m memory slots.
            back_ptr (List): back_ptr[m][lhs][rhs] records the best operation at this point. It is (True,)
                if the optimal choice is a chain checkpoint, and (False, j) if the optimal choice is a leaf
                checkpoint of length j.
        """
        ftime = chain.ftime + [0.0]
        btime = chain.btime
        x = chain.x + [0]
        xbar = chain.xbar + [0]
        ftmp = chain.ftmp + [0]
        btmp = chain.btmp + [0]

        # Build table
        cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
        back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]

        # Initialize the corner cases where the length of the sequence equals 1, i.e. lhs == rhs
        for m in range(mmax + 1):
            for i in range(len(chain) + 1):
                limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
                if m >= limit:
                    cost_table[m][i][i] = ftime[i] + btime[i]
                else:
                    cost_table[m][i][i] = float("inf")
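
        # The recurrence below mirrors the rotor formulation: to process the
        # interval [i, idx] within m slots, either
        #   (a) pick a leaf checkpoint j in (i, idx]: run F_i..F_{j-1} without
        #       storing intermediates, keep x[j], then solve [j, idx] with
        #       m - x[j] slots and [i, j - 1] with m slots, or
        #   (b) use a chain checkpoint: keep the full activation xbar[i + 1]
        #       of stage i and solve [i + 1, idx] with m - xbar[i + 1] slots.
        # The cheaper option wins and its choice is recorded in back_ptr.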
        # Compute the remaining entries of the table
        for m in range(mmax + 1):
            for d in range(1, len(chain) + 1):
                for i in range(len(chain) + 1 - d):
                    idx = i + d
                    mmin = x[idx + 1] + x[i + 1] + ftmp[i]
                    if idx > i + 1:
                        mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
                    if m < mmin:
                        cost_table[m][i][idx] = float("inf")
                    else:
                        leaf_checkpoints = [(j,
                                             sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
                                            for j in range(i + 1, idx + 1)
                                            if m >= x[j]]
                        if leaf_checkpoints:
                            best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                        else:
                            best_leaf = None
                        if m >= xbar[i + 1]:
                            chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
                        else:
                            chain_checkpoint = float("inf")
                        if best_leaf and best_leaf[1] <= chain_checkpoint:
                            cost_table[m][i][idx] = best_leaf[1]
                            back_ptr[m][i][idx] = (False, best_leaf[0])
                        else:
                            cost_table[m][i][idx] = chain_checkpoint
                            back_ptr[m][i][idx] = (True,)
        return cost_table, back_ptr

    @staticmethod
    def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
        try:
            from .rotorc import compute_table

        # build the module if it is not found
        except ModuleNotFoundError:
            import os
            import subprocess
            import sys
            logger = get_dist_logger()
            logger.info("rotorc hasn't been built! Building library...", ranks=[0])
            this_dir = os.path.dirname(os.path.abspath(__file__))
            result = subprocess.Popen(
                [
                    f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
                    f"--build-lib={this_dir}"
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            if result.wait() == 0:
                logger.info("rotorc has been built!", ranks=[0])
                from .rotorc import compute_table
            else:
                logger.warning("rotorc build failed! Falling back to the Python version!", ranks=[0])
                return CheckpointSolverRotor._compute_table(chain, mmax)
        return compute_table(chain, mmax)

    @staticmethod
    def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
                   back_ptr: List[Any]) -> "Sequence":
        """Backtrack the cost table and retrieve the optimal checkpointing strategy.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            lhs (int): The left index of the interval to backtrack.
            rhs (int): The right index of the interval to backtrack.
            budget (int): The memory budget for processing this interval.
            cost_table (List[Any]): See ``._compute_table()`` for definitions.
            back_ptr (List[Any]): See ``._compute_table()`` for definitions.

        Raises:
            ValueError: If the chain cannot be processed within the given budget.

        Returns:
            sequence (Sequence): The sequence of executing nodes with checkpoints.
        """
        if budget <= 0:
            raise ValueError(f"Can not process a chain with a non-positive memory budget {budget}")
        elif cost_table[budget][lhs][rhs] == float("inf"):
            raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")

        sequence = Sequence()
        if rhs == lhs:
            if lhs == len(chain):
                sequence += [Loss()]
            else:
                sequence += [ForwardEnable(lhs), Backward(lhs)]
            return sequence

        if back_ptr[budget][lhs][rhs][0]:    # chain checkpoint
            sequence += [
                ForwardEnable(lhs),
                CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
                                                 back_ptr),
                Backward(lhs),
            ]
        else:    # leaf checkpoint
            best_leaf = back_ptr[budget][lhs][rhs][1]
            sequence += [ForwardCheck(lhs)]
            sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
            sequence += [
                CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
                                                 back_ptr),
                CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
            ]
        return sequence

    @staticmethod
    def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
        """Annotate the nodes in the ``node_list`` with activation checkpoints from the sequence.

        Args:
            sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
            node_list (List[List[Node]]): The list of nodes to annotate.
        """
        op_list = sequence.list_operations()
        loss_op = next(op for op in op_list if isinstance(op, Loss))
        fwd_list = op_list[:op_list.index(loss_op)]
        bwd_list = op_list[op_list.index(loss_op) + 1:]
        ckpt_idx = 0
        in_ckpt = False
        ckpt_region = []

        # forward annotation
        for idx, op in enumerate(fwd_list):
            if in_ckpt:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(idx)

                elif isinstance(op, ForwardEnable):
                    in_ckpt = False
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = [idx]

            else:
                if isinstance(op, ForwardCheck):
                    in_ckpt = True
                    ckpt_region.append(idx)

        # annotate the backward pass if there are any nested activation checkpoints
        in_recompute = False
        for op in bwd_list:
            if in_recompute:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(op.index)

                elif isinstance(op, ForwardEnable):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = [op.index]

                elif isinstance(op, Backward):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    in_recompute = False

            else:
                if not isinstance(op, Backward):
                    in_recompute = True
                    ckpt_idx = 0
                    ckpt_region = []
                    if isinstance(op, ForwardCheck):
                        ckpt_region.append(op.index)

        # postprocess: make sure every activation checkpoint label in the
        # same activation checkpoint region (level = 0) has the same length
        op_list = []
        for node in node_list:
            op_list += node
        ckpt_regions = _find_nested_ckpt_regions(op_list)
        for (start_idx, end_idx) in ckpt_regions:
            nested_length = max(
                len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
            for idx in range(start_idx, end_idx + 1):
                op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
                                                                        len(op_list[idx].meta['activation_checkpoint']))