[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Super Daniel 2022-11-01 10:43:15 +08:00 committed by GitHub
parent 2b859502d5
commit 1e88811c7a
16 changed files with 1025 additions and 119 deletions

View File

@ -0,0 +1,3 @@
from .ckpt_solver_base import CheckpointSolverBase
from .ckpt_solver_chen import CheckpointSolverChen
from .ckpt_solver_rotor import CheckpointSolverRotor
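
Taken together, these exports are the solver entry points. Below is a minimal usage sketch of the intended workflow (hedged: the package path `colossalai.auto_parallel.checkpoint` and the location of the `parameter_size` helper are assumptions based on the imports in this commit): trace the model, run `MetaInfoProp`, solve, then swap in the returned graph.

import torch
from torch.fx import GraphModule

# assumed package path for this new folder
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
# assumed location of parameter_size (see memory_utils further down)
from colossalai.fx.profiler.memory_utils import parameter_size


def apply_rotor(gm: GraphModule, model: torch.nn.Module, budget: float) -> GraphModule:
    # MetaInfoProp must have been run on `gm` beforehand, otherwise the
    # solver raises a RuntimeError (see CheckpointSolverBase.__init__)
    solver = CheckpointSolverRotor(gm.graph,
                                   memory_budget=budget,
                                   parameter_size=parameter_size(model))
    gm.graph = solver.solve()    # returns a deepcopy of the annotated graph
    gm.recompile()
    return gm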

View File

@ -0,0 +1,167 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, List
from torch.fx import Graph, Node
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace
__all__ = ['CheckpointSolverBase']
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output':
n_dst.meta = n_src.meta
class CheckpointSolverBase(ABC):
def __init__(
self,
graph: Graph,
memory_budget: float = -1.0,
parameter_size: float = 0,
requires_linearize: bool = False,
cnode: List[str] = None,
):
"""CheckpointSolver class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for
target computing graph.
Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor)
Args:
graph (Graph): The computing graph to be optimized.
memory_budget (float): Memory constraint for the solution.
parameter_size (float): The size of parameter of this model. Use `parameter_size(model)` to estimate.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
Warnings:
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
# the solver is executed.
self.graph = deepcopy(graph)
self.graph.owning_module = graph.owning_module
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())
# check if `MetaInfoProp` is done
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
self.memory_budget = memory_budget
self.parameter_size = parameter_size
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
self.node_list = self._linearize_graph()
else:
self.node_list = self.get_node_list()
@abstractmethod
def solve(self):
"""Solve the checkpointing problem and return the solution.
"""
pass
def get_node_list(self):
"""Get the node list.
"""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
"""Linearizing the graph
Args:
graph (Graph): The computing graph to be optimized.
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
Do merge the inplace ops into the previous node.
"""
# Common nodes are nodes that can be seen as attributes: they remain unchanged
# throughout the whole model and are used several times by different blocks of
# the model, which makes it hard to linearize the graph when we encounter them.
# We let users annotate some of the inputs as common nodes (such as the attention
# mask), and the following are some ops that can also be treated as common nodes.
# With this common-node propagation we can find the "real" common nodes (e.g. the
# actual attention mask used in BERT and GPT). The rule is simple: a node whose
# parents are all common nodes, or whose op belongs to the following operations,
# is viewed as a newly born common node.
# List of target name that could be seen as common node
common_ops = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in common_ops
else:
return target.__name__ in common_ops
def _is_sink() -> bool:
"""Check if we can free all dependencies
Returns:
bool
"""
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
self.cnode = []
deps = {}
node_list = []
region = []
for n in self.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n.all_input_nodes:
if n_par.op != "placeholder" and n_par.name not in self.cnode:
deps[n_par] -= 1
region.append(n)
# if the node could free all dependencies in graph
# we could begin a new node
if _is_sink():
node_list.append(region)
region = []
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
return node_list
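
To make the common-node rule concrete, here is a small hedged sketch (the node name 'attention_mask' and the traced module `gm` are hypothetical): marking a reused input as a common node keeps its many consumers from blocking `_is_sink()`, and nodes derived purely from common nodes (e.g. a `size` or `getitem` on the mask) are promoted to common nodes as well.

# hypothetical usage: 'attention_mask' must be a placeholder of the traced graph
solver = CheckpointSolverChen(gm.graph, cnode=['attention_mask'])
# node_list is a List[List[Node]]: each inner list is one linearized region,
# with inplace ops merged into the region of their producing node
for region in solver.node_list:
    print([n.name for n in region])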

View File

@ -0,0 +1,87 @@
import math
from copy import deepcopy
from typing import List, Set, Tuple
from torch.fx import Graph, Node
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
__all__ = ['CheckpointSolverChen']
class CheckpointSolverChen(CheckpointSolverBase):
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets memory optimization only, using the techniques in Appendix A.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverChen`:
>>> solver = CheckpointSolverChen(gm.graph)
>>> chen_graph = solver.solve()
>>> gm.graph = chen_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
num_grids (int, optional): Number of grids to search for b. Defaults to 6.
"""
super().__init__(graph, 0, 0, True, cnode)
self.num_grids = num_grids
def solve(self) -> Graph:
"""Solve the checkpointing problem using Algorithm 3.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
n.meta['activation_checkpoint'] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for idx, nodes in enumerate(self.node_list):
for n in nodes:
n: Node
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
y = max(y, temp)
if temp > b and idx > prev_idx:
x += calculate_fwd_in(nodes[0])
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))
def grid_search(self) -> Set:
"""
Search for a ckpt strategy with b = 0 to obtain the estimate b = sqrt(x * y), then
run the allocation algorithm again, grid-searching b over [sqrt(2)/2 * b, sqrt(2) * b]
with num_grids steps as in Appendix A.
"""
_, b_approx = self.run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // self.num_grids):
ckpt_intv, b_approx = self.run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt_intv
return ckpt_opt
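
The greedy pass can be read independently of torch.fx. A self-contained sketch of the same idea on plain per-node memory costs (simplified: it drops the `prev_idx = 2` warm-up offset of the class above and uses the boundary node's own cost as a stand-in for `calculate_fwd_in`):

import math
from typing import List, Tuple


def chen_greedy_sketch(costs: List[int], b: int = 0) -> Tuple[List[Tuple[int, int]], int]:
    """Mirror of run_chen_greedy on synthetic costs: `temp` accumulates the
    running activation memory, `y` tracks its peak, and every time `temp`
    exceeds the budget `b` a checkpoint interval is closed and the boundary
    cost is added to `x`. sqrt(x * y) seeds the grid search, as in the paper."""
    ckpt_intv, temp, x, y, prev_idx = [], 0, 0, 0, 0
    for idx, cost in enumerate(costs):
        temp += cost
        y = max(y, temp)
        if temp > b and idx > prev_idx:
            x += cost          # simplified stand-in for calculate_fwd_in(nodes[0])
            temp = 0
            ckpt_intv.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    return ckpt_intv, math.floor(math.sqrt(x * y))


# e.g. chen_greedy_sketch([4, 4, 4, 4, 4, 4], b=8) -> ([(0, 3), (3, 6)], 9)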

View File

@ -0,0 +1,387 @@
from copy import deepcopy
from typing import Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
calculate_bwd_time,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
__all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self,
graph: Graph,
memory_budget: float = -1,
parameter_size: float = 0,
cnode: List[str] = None,
memory_slots: int = 500):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`:
>>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size)
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
memory_budget (float, optional): Memory constraint for the solution, in bytes.
parameter_size (float, optional): The size of the parameters of this model, in bytes. Use `parameter_size(model)` to estimate.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
"""
super().__init__(graph, memory_budget, parameter_size, True, cnode)
self.memory_slots = memory_slots
# construct chain
unit = self.memory_budget // self.memory_slots
self.chain = self._construct_chain(self.graph, self.node_list)
self.chain.discretize_all(unit)
self.cost_table = None
self.back_ptr = None
self.sequence = None
def solve(self, force_python: bool = False) -> Graph:
"""Solve the checkpointing problem using rotor algorithm.
Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
chain = self.chain
# compute cost table
if force_python:
self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
else:
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
# backtrack
try:
self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr)
self._annotate_from_sequence(self.sequence, self.node_list)
except RuntimeError as e:
# use the logger to announce that the solver failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
return deepcopy(self.graph)
def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
self.chain.btmp[idx])
print(f'Chain = {self.chain}')
def print_sequence(self):
print(f'Sequence = {self.sequence}')
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
input_tensors = cls._extract_input(graph)
fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list()
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
for idx, node in enumerate(node_list):
node_info = cls._extract_node_info(node)
fwd_time.append(node_info[0])
bwd_time.append(node_info[1])
x.append(node_info[2])
xbar.append(node_info[3])
ftmp.append(node_info[4])
btmp.append(node_info[5])
# currently we view loss backward temp as zero
bwd_time.append(0)
btmp.append(0)
return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp)
@classmethod
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
"""Extract node info from a list of nodes"""
xbar = 0
fwd_time = 0
bwd_time = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
# minimum flop count is required
fwd_time += max(calculate_fwd_time(n), 1.0)
bwd_time += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
btmp = cls._extract_btmp(node)
return fwd_time, bwd_time, x, xbar, ftmp, btmp
@staticmethod
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
if node.op == 'placeholder':
input_tensors.append(node.meta['fwd_out'])
return input_tensors
@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
"""Extract ftmp from a list of nodes"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
"""Extract btmp from a list of nodes"""
def _extract_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
btmp = 0
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
return btmp
@staticmethod
def _compute_table(chain: Chain, mem_slots: int) -> Tuple:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mem_slots (int): Number of slots for discretizing memory budget.
Returns:
cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice
is a chain checkpoint, or (False, j) if the optimal choice is a leaf checkpoint
of length j
"""
ftime = chain.ftime + [0.0]
btime = chain.btime
x = chain.x + [0]
xbar = chain.xbar + [0]
ftmp = chain.ftmp + [0]
btmp = chain.btmp + [0]
# Build table
cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mem_slots + 1):
for i in range(chain.length + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1)
cost_table[m][i][i] = ftime[i] + btime[i]
else:
cost_table[m][i][i] = float("inf")
# Compute everything
for m in range(mem_slots + 1):
for d in range(1, chain.length + 1):
for i in range(chain.length + 1 - d):
idx = i + d
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
if idx > i + 1:
mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j,
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
best_leaf = None
if m >= xbar[i + 1]:
chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
else:
chain_checkpoint = float("inf")
if best_leaf and best_leaf[1] <= chain_checkpoint:
cost_table[m][i][idx] = best_leaf[1]
back_ptr[m][i][idx] = (False, best_leaf[0])
else:
cost_table[m][i][idx] = chain_checkpoint
back_ptr[m][i][idx] = (True,)
return cost_table, back_ptr
@staticmethod
def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple:
raise NotImplementedError("C implementation not available yet")
def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]],
back_ptr: List[List[Dict[int, Tuple]]]) -> Sequence:
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
lmin (int): The left index of the interval to backtrack.
lmax (int): The right index of the interval to backtrack.
mem_budget (int): The memory budget for processing this interval.
cost_table (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
back_ptr (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
Raises:
ValueError: Cannot process the chain.
Returns:
sequence (Sequence): The sequence of executing nodes with checkpoints.
"""
if mem_budget <= 0:
raise ValueError(f"Cannot process a chain with non-positive memory budget {mem_budget}")
elif cost_table[mem_budget][lmin][lmax] == float("inf"):
raise ValueError(f"Cannot process this chain from index {lmin} to {lmax} with memory {mem_budget}")
sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget))
if lmin == lmax:
if lmin == chain.length:
sequence.insert(Loss())
else:
sequence.insert(ForwardEnable(lmin))
sequence.insert(Backward(lmin))
return sequence
if back_ptr[mem_budget][lmin][lmax][0]:
sequence.insert(ForwardEnable(lmin))
sequence.insert_sequence(
self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr))
sequence.insert(Backward(lmin))
else:
j = back_ptr[mem_budget][lmin][lmax][1]
sequence.insert(ForwardCheck(lmin))
for k in range(lmin + 1, j):
sequence.insert(ForwardNograd(k))
sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr))
sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr))
return sequence
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for idx, op in enumerate(fwd_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(idx)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))
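
For readers following `_compute_table`, here is the same recurrence as a top-down sketch. Hedged restatement: it assumes the padded lists `ftime`, `btime`, `x`, `xbar`, `ftmp`, `btmp` built at the top of `_compute_table` are in scope, and it omits memoization, the `mmin` pruning bound, and the back-pointer bookkeeping for clarity.

def cost(m: int, i: int, l: int) -> float:
    """Top-down form of the rotor DP filled by _compute_table (sketch only)."""
    if i == l:    # Equation (1): one stage fits iff the budget covers its peak
        limit = x[i + 1] + xbar[i + 1] + max(ftmp[i], btmp[i])
        return ftime[i] + btime[i] if m >= limit else float('inf')
    # leaf checkpoint at j: store x[j], replay stages i..j-1 later
    leaf = min((sum(ftime[i:j]) + cost(m - x[j], j, l) + cost(m, i, j - 1)
                for j in range(i + 1, l + 1) if m >= x[j]),
               default=float('inf'))
    # chain checkpoint at i: keep all of stage i's activations (xbar[i+1])
    chain = (cost(m, i, i) + cost(m - xbar[i + 1], i + 1, l)
             if m >= xbar[i + 1] else float('inf'))
    # back_ptr in the table records which branch wins: (True,) vs (False, j)
    return min(leaf, chain)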

View File

@ -0,0 +1,241 @@
import math
from abc import ABC
from typing import List
from torch.utils._pytree import tree_map
class Chain:
def __init__(self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
Args:
ftime (List[float]): The forward time of each node.
btime (List[float]): The backward time of each node.
x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
ftmp (List[int]): The temporary forward memory of each node.
btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True.
"""
self.ftime = ftime
self.btime = btime
self.x = x
self.xbar = xbar
self.ftmp = ftmp
self.btmp = btmp
self.length = len(ftime)
if check_consistency and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.ftime) == self.length) and (len(self.btime) == self.length + 1)
and (len(self.x) == self.length + 1) and (len(self.ftmp) == self.length)
and (len(self.btmp) == self.length + 1) and (len(self.xbar) == self.length + 1))
def __repr__(self):
chain_list = []
for i in range(self.length):
chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
i = self.length
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
return chain_list.__repr__()
def discretize_all(self, unit: int):
"""Discretize the chain into a list of chains according to unit size."""
discretizer = lambda val: math.ceil(val / unit)
self.x = tree_map(discretizer, self.x)
self.xbar = tree_map(discretizer, self.xbar)
self.ftmp = tree_map(discretizer, self.ftmp)
self.btmp = tree_map(discretizer, self.btmp)
class Operation(ABC):
name = "Op"
def __repr__(self) -> str:
return f"{self.name}_{self.index}"
def shift(self, value):
if type(self.index) is tuple:
self.index = tuple(x + value for x in self.index)
else:
self.index += value
class Forward(Operation):
name = "F"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
if chain is not None:
return chain.ftime[self.index]
else:
return 1
class ForwardEnable(Forward):
name = "Fe"
class ForwardNograd(Forward):
name = "Fn"
class ForwardCheck(Forward):
name = "CF"
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
def __repr__(self):
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
def isForward(op):
return type(op) is Forward or type(op) is Forwards
class Backward(Operation):
name = "B"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
if chain is not None:
return chain.btime[self.index]
else:
return 1
class Loss(Operation):
def __init__(self):
pass
def __repr__(self):
return "L"
def cost(self, chain):
return 0
class MemoryAccess(Operation):
name = "MA"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
return 0
class WriteMemory(MemoryAccess):
name = "WM"
class ReadMemory(MemoryAccess):
name = "RM"
class DiscardMemory(MemoryAccess):
name = "DM"
class Function:
def __init__(self, name, *args):
self.name = name
self.args = args
self.str_args = ','.join(str(v) for v in self.args)
def __repr__(self):
return "{n}({args})".format(n=self.name, args=self.str_args)
class Sequence:
def __init__(self, function):
self.sequence = []    # list of Operation and Sequence objects
self.function = function    # describes the function (name and parameters)
def __repr__(self):
return repr(self.list_operations())
def list_operations(self):
op_list = []
for x in self.sequence:
if isinstance(x, Operation):
op_list.append(x)
else:
assert isinstance(x, Sequence)
op_list += x.list_operations()
return op_list
def insert(self, operation):
self.sequence.append(operation)
def remove(self, operation_index):
del self.sequence[operation_index]
def insert_sequence(self, sequence):
self.sequence.append(sequence)
def shift(self, value):
for x in self.sequence:
x.shift(value)
return self
def remove_useless_write(self):
if self.sequence:
if isinstance(self.sequence[0], WriteMemory):
self.remove(0)
return self
def get_makespan(self, chain):
return sum(op.cost(chain) for op in self.list_operations())
def without_suffix(self):
ops = self.list_operations()
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
try:
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
except ValueError:
last_idx = -1
if last_idx == end_of_first_phase - 1:
return (self, None)
# Some assumptions here about the sequence: it finishes with Forward_L and
# starts with B_L, but this should be fine in practice
chain_length = ops[end_of_first_phase - 1].index
start_of_fwd_enable_chain = ops[last_idx + 1].index
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
for i in range(last_idx + 1):
result.insert(ops[i])
result.insert(Loss())
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
position = end_of_first_phase + 1 + (chain_length - i)
assert type(ops[position]) is Backward
assert ops[position].index == i
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
result.insert(ops[i])
return (result, start_of_fwd_enable_chain)
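
A minimal, self-contained example of building a `Chain` and discretizing it (synthetic numbers, purely illustrative; the module path is an assumption): the list lengths must satisfy `check_lengths()`, i.e. `ftime`/`ftmp` have `length` entries while `btime`/`btmp`/`x`/`xbar` carry one extra slot for the loss stage.

# assumed module path for this new file
from colossalai.auto_parallel.checkpoint.operation import Chain

chain = Chain(ftime=[1.0, 1.0],
              btime=[2.0, 2.0, 0.0],
              x=[1024, 2048, 512],
              xbar=[1024, 4096, 512],
              ftmp=[256, 256],
              btmp=[128, 128, 0])
chain.discretize_all(512)   # ceil-divide every memory entry by the unit
print(chain.x)              # [2, 4, 1]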

View File

@ -1,14 +1,37 @@
import colossalai
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from typing import List, Callable, Any, Tuple, Dict, Iterable
import colossalai
try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import (
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_args,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks():
return (x.device, x.cpu())
else:
return x
def pack_hook_no_input(self, x):
if getattr(x, "offload", True):
return (x.device, x.cpu())
@ -48,11 +71,9 @@ def pack_hook_no_input(self, x):
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
"""Generate customized saved_tensors_hooks
Args:
offload_input (bool, optional): whether we need to offload the input; if offload_input=False,
we will use self.pack_hook_no_input instead. Defaults to True.
Returns:
str: generated context
"""
@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
act_ckpt_label = node.activation_checkpoint
if 'activation_checkpoint' in node.meta:
act_ckpt_label = node.meta['activation_checkpoint']
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
elif current_region is not None and not 'activation_checkpoint' in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]):
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In the pofo algorithm, during annotation, we will annotate the offload region with a
list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
region's index, offload_input is a bool indicating whether we need to offload
the input, and offload_bar is a bool indicating whether we need to offload all the
@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
act_offload_label = node.activation_offload
if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
act_offload_label = node.meta['activation_offload']
if current_region == None:
current_region = act_offload_label
@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
"""Check if the node could end the ckpt region
Args:
node (Node): torch.fx.Node
check_idx (int): the index of checkpoint level for
nested checkpoint
Returns:
bool
"""
if hasattr(node, "activation_checkpoint"):
if isinstance(node.activation_checkpoint, list):
return node.activation_checkpoint[check_idx] == None
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], list):
return node.meta['activation_checkpoint'][check_idx] == None
else:
return False
else:
@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
def _find_nested_ckpt_regions(nodes, check_idx=0):
"""
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
will be list of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_regions = []
@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
if isinstance(getattr(node, 'activation_checkpoint'), int):
act_ckpt_label = node.activation_checkpoint
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], int):
act_ckpt_label = node.meta['activation_checkpoint']
else:
act_ckpt_label = node.activation_checkpoint[check_idx]
act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@ -287,7 +306,6 @@ def emit_ckpt_func(body,
level=0,
in_ckpt=False):
"""Emit ckpt fuction in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
functions code
@ -303,8 +321,8 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].activation_checkpoint, int):
label = node_list[0].activation_checkpoint
if isinstance(node_list[0].meta['activation_checkpoint'], int):
label = node_list[0].meta['activation_checkpoint']
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
for node in node_list:
@ -313,7 +331,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
@ -322,12 +340,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
# if there is more level to fetch
if level + 1 < len(node_list[0].activation_checkpoint):
if level + 1 < len(node_list[0].meta['activation_checkpoint']):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@ -354,7 +372,7 @@ def emit_ckpt_func(body,
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func += ckpt_func_buffer
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
@ -368,7 +386,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
"""Emit code with nested activation checkpoint
When we detect that a node's activation_checkpoint is a List, we will use
this function to emit the activation checkpoint code.
Args:
body: forward code
ckpt_func: checkpoint functions code
@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
if hasattr(node_list[start_node_idx], 'activation_offload'):
activation_offload = node_list[start_node_idx].activation_offload
if 'activation_offload' in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta['activation_offload']
else:
activation_offload = False
@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
if hasattr(user, "activation_checkpoint"):
if user.activation_checkpoint == label:
if 'activation_checkpoint' in user.meta:
if user.meta['activation_checkpoint'] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE:
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE:
code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
return PythonCode(fn_code, globals_)
@ -851,10 +865,8 @@ else:
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
@ -999,7 +1011,7 @@ else:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@ -1040,7 +1052,6 @@ else:
# in forward function
fn_code = f"""
{wrap_stmts}
{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
{code}"""

View File

@ -13,10 +13,10 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`.
Returns:
int: The activation size
int: The activation size, unit is byte.
"""
act_size = 0
if isinstance(out, torch.Tensor):
@ -38,10 +38,10 @@ def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
mod (torch.nn.Module): The target `torch.nn.Module`.
Returns:
int: The parameter size
int: The parameter size, unit is byte.
"""
param_size = 0
for param in mod.parameters():

View File

@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
def pack(x):
global cache, do_not_cache
if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach()
tensor.uuid = x._tensor.uuid
tensor.data_ptr = x._tensor.data_ptr
x._node.meta['saved_tensor'] += [tensor]
if not do_not_cache:
cache.add(x._tensor.uuid)
cache.add(x._tensor.data_ptr())
return x
def unpack(x):
@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
def extract_tensor(x: Any):
if isinstance(x, MetaTensor):
tensor = x._tensor.detach()
tensor.uuid = x._tensor.uuid
tensor.data_ptr = x._tensor.data_ptr
return tensor
if not isinstance(x, torch.finfo):
return x

View File

@ -87,8 +87,8 @@ def calculate_fwd_out(n: Node) -> int:
fwd_in = dict()
for u in n.users:
fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
return activation_size(intersect(fwd_in, fwd_out))

View File

@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN
__all__ = ['MetaTensor']
def set_uuid(x):
def set_data_ptr(x):
if isinstance(x, torch.Tensor):
if not hasattr(x, 'uuid'):
setattr(x, 'uuid', uuid.uuid4())
if not x.data_ptr():
data_ptr = uuid.uuid4()
x.data_ptr = lambda: data_ptr
@compatibility(is_backward_compatible=False)
@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor):
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
# only tensor not on `meta` should be copied to `meta`
set_uuid(r._tensor)
set_data_ptr(r._tensor)
return r
def __repr__(self):
@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor):
# here we keep the data_ptr of the input because ops in ALIAS_ATEN do not generate
# a physical copy of the input
if func in ALIAS_ATEN:
setattr(out, 'uuid', args[0].uuid)
out.data_ptr = args[0].data_ptr
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass

View File

@ -1,26 +1,28 @@
#!/usr/bin/env python
"""
tracer.py:
Implemented a tracer which supports control flow and user-defined meta arguments.
The implementation is partly inspired by HuggingFace's fx tracer.
"""
import enum
import inspect
import functools
import inspect
import operator
from contextlib import contextmanager
from colossalai.fx.tracer.meta_patch import meta_patched_module
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.fx import Tracer, Node
from torch.fx.graph import Graph
from torch.fx.proxy import Proxy, ParameterProxy
from torch.fx import Node, Tracer
from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
from torch.fx.proxy import ParameterProxy, Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_module
from ..proxy import ColoProxy
from typing import Optional, Dict, Any
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .meta_patch import meta_patched_function, meta_patched_module
from torch.fx.graph import magic_methods, reflectable_magic_methods
__all__ = ['ColoTracer']
@ -231,7 +233,7 @@ class ColoTracer(Tracer):
Args:
root (nn.Module): a `nn.Module` object to trace the computation graph
meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
"""
@ -383,7 +385,7 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
node.meta['activation_checkpoint'] = self.act_ckpt_region_count
return node

View File

@ -2,11 +2,13 @@ import copy
import re
from typing import Callable
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
@ -14,7 +16,6 @@ from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from torch.fx import GraphModule
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@ -94,6 +95,7 @@ def _run_ckpt_solver(rank):
gpc.destroy()
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)

View File

@ -1,14 +1,15 @@
import torch
import torch.nn.functional as F
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from torch.utils.checkpoint import checkpoint
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank):
offload_starts = ['mlp1_linear1']
for node in graph.nodes:
if node.name in ckpt_nodes:
assert hasattr(node, 'activation_checkpoint')
assert 'activation_checkpoint' in node.meta
# annotate the selected node for offload
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
node.meta['activation_offload'] = True
gm = ColoGraphModule(model, graph)
gm.recompile()
@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank):
offload_starts = ['mlp1_linear1']
for node in graph.nodes:
if node.name in ckpt_nodes:
assert hasattr(node, 'activation_checkpoint')
assert 'activation_checkpoint' in node.meta
# annotate the selected node for offload
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
node.meta['activation_offload'] = True
gm = ColoGraphModule(model, graph)
gm.recompile()

View File

@ -1,14 +1,15 @@
import torch
import torch.nn.functional as F
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from torch.utils.checkpoint import checkpoint
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank):
# annotate nested checkpoint
for node in graph.nodes:
if node.name == "linear1":
setattr(node, "activation_checkpoint", [0, 0, 0])
node.meta['activation_checkpoint'] = [0, 0, 0]
continue
if node.name == "linear2":
setattr(node, "activation_checkpoint", [0, 0, None])
node.meta['activation_checkpoint'] = [0, 0, None]
if node.name == "linear3":
setattr(node, "activation_checkpoint", [0, 0, 1])
node.meta['activation_checkpoint'] = [0, 0, 1]
if node.name == "linear4":
setattr(node, "activation_checkpoint", [0, 1, None])
node.meta['activation_checkpoint'] = [0, 1, None]
if node.name == "linear5":
setattr(node, "activation_checkpoint", 1)
node.meta['activation_checkpoint'] = 1
gm = ColoGraphModule(model, graph)
gm.recompile()
@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank):
# annotate nested checkpoint
for node in graph.nodes:
if node.name == "linear1":
setattr(node, "activation_checkpoint", [0, 0, 0])
node.meta['activation_checkpoint'] = [0, 0, 0]
continue
if node.name == "linear2":
setattr(node, "activation_checkpoint", [0, 0, None])
node.meta['activation_checkpoint'] = [0, 0, None]
if node.name == "linear3":
setattr(node, "activation_checkpoint", [0, 0, 1])
node.meta['activation_checkpoint'] = [0, 0, 1]
if node.name == "linear4":
setattr(node, "activation_checkpoint", [0, 1, None])
node.meta['activation_checkpoint'] = [0, 1, None]
if node.name == "linear5":
setattr(node, "activation_checkpoint", 1)
node.meta['activation_checkpoint'] = 1
gm = ColoGraphModule(model, graph)
gm.recompile()

View File

@ -1,14 +1,16 @@
import copy
import torch
import torch.nn.functional as F
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear2":
setattr(node, "activation_offload", [1, True, True])
node.meta['activation_offload'] = [1, True, True]
if node.name == "linear4":
setattr(node, "activation_offload", [2, False, True])
node.meta['activation_offload'] = [2, False, True]
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
node.meta['activation_checkpoint'] = [0]
node.meta['activation_offload'] = True
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()
@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear2":
setattr(node, "activation_offload", [1, True, True])
node.meta['activation_offload'] = [1, True, True]
if node.name == "linear4":
setattr(node, "activation_offload", [2, False, True])
node.meta['activation_offload'] = [2, False, True]
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
node.meta['activation_checkpoint'] = [0]
node.meta['activation_offload'] = True
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()

View File

@ -1,9 +1,10 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from torch.utils.checkpoint import checkpoint
from colossalai.fx import ColoTracer
class MLP(torch.nn.Module):
@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation():
for node in gm.graph.nodes:
if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
assert getattr(node, 'activation_checkpoint', -1) == 0
assert node.meta.get('activation_checkpoint', -1) == 0
for node in gm.graph.nodes:
if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
assert getattr(node, 'activation_checkpoint', -1) == 1
assert node.meta.get('activation_checkpoint', -1) == 1
tracer = ColoTracer(trace_act_ckpt=False)
graph = tracer.trace(module)