mirror of https://github.com/hpcaitech/ColossalAI

[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.
* [autoparallel] add solver rotor.
* [autoparallel] add ckpt solvers.
* [autoparallel] modify codegen.
* [fx] fix annotation in test.
* [fx] remove check.
* [autoparallel] polish docstring.
* [fx] refactor MetaTensor.

parent 2b859502d5
commit 1e88811c7a
@@ -0,0 +1,3 @@
from .ckpt_solver_base import CheckpointSolverBase
from .ckpt_solver_chen import CheckpointSolverChen
from .ckpt_solver_rotor import CheckpointSolverRotor
@@ -0,0 +1,167 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, List

from torch.fx import Graph, Node

from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace

__all__ = ['CheckpointSolverBase']


def _copy_output(src: Graph, dst: Graph):
    """Copy the meta of the output node from src to dst"""
    for n_src, n_dst in zip(src.nodes, dst.nodes):
        if n_src.op == 'output':
            n_dst.meta = n_src.meta


class CheckpointSolverBase(ABC):

    def __init__(
        self,
        graph: Graph,
        memory_budget: float = -1.0,
        parameter_size: float = 0,
        requires_linearize: bool = False,
        cnode: List[str] = None,
    ):
        """A CheckpointSolver integrates the information provided by its components
        and uses an existing solver to find a possibly optimal strategy combination
        for the target computing graph.

        Existing solvers:
            Chen's greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
            Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor)

        Args:
            graph (Graph): The computing graph to be optimized.
            memory_budget (float): Memory constraint for the solution.
            parameter_size (float): The size of the parameters of this model. Use `parameter_size(model)` to estimate.
            requires_linearize (bool): Whether the graph needs to be linearized.
            cnode (List[str], optional): Common node list, should be a subset of the input. Defaults to None.

        Warnings:
            `MetaInfoProp` should be run before constructing the solver, because the meta information of the graph is required.
        """
        # super-dainiu: this graph is a temporary graph which can refer to
        # the owning module, but we will return another deepcopy of it after
        # the solver is executed.
        self.graph = deepcopy(graph)
        self.graph.owning_module = graph.owning_module
        _copy_output(graph, self.graph)
        self.graph.set_codegen(ActivationCheckpointCodeGen())

        # check if `MetaInfoProp` has been run
        if any(len(node.meta) == 0 for node in self.graph.nodes):
            raise RuntimeError(
                "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")

        self.memory_budget = memory_budget
        self.parameter_size = parameter_size
        self.cnode = cnode
        self.requires_linearize = requires_linearize
        if self.requires_linearize:
            self.node_list = self._linearize_graph()
        else:
            self.node_list = self.get_node_list()

    @abstractmethod
    def solve(self):
        """Solve the checkpointing problem and return the solution."""
        pass

    def get_node_list(self):
        """Get the node list."""
        return [[node] for node in self.graph.nodes]

    def _linearize_graph(self) -> List[List[Node]]:
        """Linearize `self.graph` into a list of regions.

        Returns:
            List[List[Node]]: List of lists; each inner list of `Node`s represents
            one region of the linearized graph.

        Note:
            Inplace ops are merged into the previous node.
        """

        # Common nodes are nodes that can be seen as attributes that remain unchanged
        # throughout the whole model; such a node may be used several times by different
        # blocks of the model, which makes it hard to linearize the graph when we
        # encounter it. We let users annotate some of the inputs as common nodes, such
        # as the attention mask, and the following are some of the ops whose outputs can
        # also be treated as common nodes. With common node propagation we can find some
        # of the "real" common nodes (e.g. the actual attention mask used in BERT and
        # GPT). The rule is simple: a node whose parents are all common nodes, or whose
        # op belongs to the following operations, is viewed as a newly born common node.
        # List of target names that can be seen as common nodes
        common_ops = ["getattr", "getitem", "size"]

        def _is_cop(target: Any) -> bool:
            """Check if an op can be seen as a common node

            Args:
                target (Any): node target

            Returns:
                bool
            """
            if isinstance(target, str):
                return target in common_ops
            else:
                return target.__name__ in common_ops

        def _is_sink() -> bool:
            """Check if we can free all dependencies

            Returns:
                bool
            """
            return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))

        # make sure that every item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                try:
                    assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
                        f"Common node {name} is not an input of the model."
                except StopIteration:
                    raise ValueError(f"Common node name {name} not in graph.")
        else:
            self.cnode = []

        deps = {}
        node_list = []
        region = []

        for n in self.graph.nodes:
            if n.op != "placeholder" and n.op != "output":
                for n_par in n.all_input_nodes:
                    if n_par.op != "placeholder" and n_par.name not in self.cnode:
                        deps[n_par] -= 1
                region.append(n)

                # if the node can free all its dependencies in the graph,
                # we can begin a new region
                if _is_sink():
                    node_list.append(region)
                    region = []

                # propagate the common node attr if possible
                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
                                                 ]) or _is_cop(n.target):
                    self.cnode.append(n.name)
                else:
                    deps[n] = len([user for user in n.users if user.op != "output"])
        return node_list
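
Note: a minimal sketch of the workflow the docstrings above assume, pieced together from the tests later in this diff; the exact `MetaInfoProp` call signature is an assumption, and `MyModule` is a hypothetical stand-in for any `nn.Module`:

    import torch
    from torch.fx import GraphModule

    from colossalai.fx import ColoTracer
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp

    model = MyModule()                          # hypothetical nn.Module
    data = torch.rand(4, 128, device='meta')

    graph = ColoTracer().trace(model, meta_args={'x': data})
    gm = GraphModule(model, graph)
    MetaInfoProp(gm).run(data)    # fills node.meta; the solver refuses to run without it

    solver = CheckpointSolverChen(gm.graph)
    gm.graph = solver.solve()     # returns an annotated deepcopy of the graph
    gm.recompile()
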
@@ -0,0 +1,87 @@
import math
from copy import deepcopy
from typing import List, Set, Tuple

from torch.fx import Graph, Node

from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

from .ckpt_solver_base import CheckpointSolverBase

__all__ = ['CheckpointSolverChen']


class CheckpointSolverChen(CheckpointSolverBase):

    def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
        """
        This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
        Note that this algorithm targets memory optimization only, using the techniques in appendix A.

        Usage:
            Assume that we have a `GraphModule`, and we have already applied `MetaInfoProp`
            to the graph to retrieve all the information needed; then we can use the
            following code to find a solution using `CheckpointSolverChen`:
            >>> solver = CheckpointSolverChen(gm.graph)
            >>> chen_graph = solver.solve()
            >>> gm.graph = chen_graph    # set the graph to a new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            cnode (List[str], optional): Common node list, should be a subset of the input. Defaults to None.
            num_grids (int, optional): Number of grids to search for b. Defaults to 6.
        """
        super().__init__(graph, 0, 0, True, cnode)
        self.num_grids = num_grids

    def solve(self) -> Graph:
        """Solve the checkpointing problem using Algorithm 3.

        Returns:
            graph (Graph): The optimized graph, should be a copy of the original graph.
        """
        checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
        ckpt = self.grid_search()
        for i, seg in enumerate(ckpt):
            for idx in range(*seg):
                nodes = self.node_list[idx]
                for n in nodes:
                    if n.op in checkpointable_op:
                        n.meta['activation_checkpoint'] = i
        return deepcopy(self.graph)

    def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
        """
        This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
        """
        ckpt_intv = []
        temp = 0
        x = 0
        y = 0
        prev_idx = 2
        for idx, nodes in enumerate(self.node_list):
            for n in nodes:
                n: Node
                temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
                y = max(y, temp)
            if temp > b and idx > prev_idx:
                x += calculate_fwd_in(nodes[0])
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    def grid_search(self) -> Set:
        """
        Search for a ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
        Grid search over [√2/2 b, √2 b] for ckpt_opt over `num_grids`, as in appendix A.
        """
        _, b_approx = self.run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        b_opt = math.inf
        for b in range(b_min, b_max, (b_max - b_min) // self.num_grids):
            ckpt_intv, b_approx = self.run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt
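
Note: the greedy pass above keeps a running activation total `temp` and starts a new checkpoint segment once it exceeds the budget `b`; `x` accumulates the segment-boundary activations and `y` tracks the in-segment peak, so √(xy) approximates the budget that balances the two. A self-contained toy run of the same idea (a standalone re-implementation over plain integers, not the solver's API; the costs are made up):

    import math

    def toy_chen_greedy(costs, b):
        """Segment a list of per-node activation costs with budget b."""
        segments, temp, x, y, prev = [], 0, 0, 0, 0
        for idx, c in enumerate(costs):
            temp += c
            y = max(y, temp)
            if temp > b and idx > prev:
                x += costs[prev]          # pay to keep the segment boundary
                temp = 0
                segments.append((prev, idx + 1))
                prev = idx + 1
        return segments, math.floor(math.sqrt(x * y))

    costs = [4, 2, 8, 3, 5, 1, 7, 2]
    _, b_approx = toy_chen_greedy(costs, 0)    # first pass estimates sqrt(xy)
    print(toy_chen_greedy(costs, b_approx))    # second pass uses the estimate
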
@@ -0,0 +1,387 @@
from copy import deepcopy
from typing import Dict, List, Tuple

from torch import Tensor
from torch.fx import Graph, Node

from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
    activation_size,
    calculate_bwd_time,
    calculate_fwd_out,
    calculate_fwd_time,
    calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger

from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence

__all__ = ['CheckpointSolverRotor']


class CheckpointSolverRotor(CheckpointSolverBase):

    def __init__(self,
                 graph: Graph,
                 memory_budget: float = -1,
                 parameter_size: float = 0,
                 cnode: List[str] = None,
                 memory_slots: int = 500):
        """This is a simple implementation of the dynamic programming algorithm rotor
        from https://hal.inria.fr/hal-02352969. Some code is adapted from
        https://gitlab.inria.fr/hiepacs/rotor.

        Usage:
            Assume that we have a `GraphModule`, and we have already applied `MetaInfoProp`
            to the graph to retrieve all the information needed; then we can use the
            following code to find a solution using `CheckpointSolverRotor`:
            >>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size)
            >>> rotor_graph = solver.solve(force_python=True)    # otherwise use C solver
            >>> gm.graph = rotor_graph    # set the graph to a new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            memory_budget (float, optional): Memory constraint for the solution, unit is byte.
            parameter_size (float, optional): The size of the parameters of this model, unit is byte. Use `parameter_size(model)` to estimate.
            cnode (List[str], optional): Common node list, should be a subset of the input. Defaults to None.
            memory_slots (int, optional): Number of slots for discretizing the memory budget. Defaults to 500.
        """
        super().__init__(graph, memory_budget, parameter_size, True, cnode)
        self.memory_slots = memory_slots

        # construct chain
        unit = self.memory_budget // self.memory_slots
        self.chain = self._construct_chain(self.graph, self.node_list)
        self.chain.discretize_all(unit)

        self.cost_table = None
        self.back_ptr = None
        self.sequence = None

    def solve(self, force_python: bool = False) -> Graph:
        """Solve the checkpointing problem using the rotor algorithm.

        Args:
            force_python (bool, optional): Use the Python version of the solver, else use the C version. Defaults to False.

        Returns:
            graph (Graph): The optimized graph, should be a copy of the original graph.
        """
        chain = self.chain

        # compute cost table
        if force_python:
            self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
        else:
            self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)

        # backtrack
        try:
            self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr)
            self._annotate_from_sequence(self.sequence, self.node_list)
        except (RuntimeError, ValueError) as e:
            # use the logger to announce that the solver failed
            logger = get_dist_logger()
            logger.warning(f'Checkpoint solver failed: {e}')

        return deepcopy(self.graph)

    def print_chain(self):
        print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
        for idx in range(len(self.node_list) - 1):
            print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
                  self.chain.btmp[idx])
        print(f'Chain = {self.chain}')

    def print_sequence(self):
        print(f'Sequence = {self.sequence}')

    @classmethod
    def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
        input_tensors = cls._extract_input(graph)
        fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list()
        xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]

        for idx, node in enumerate(node_list):
            node_info = cls._extract_node_info(node)
            fwd_time.append(node_info[0])
            bwd_time.append(node_info[1])
            x.append(node_info[2])
            xbar.append(node_info[3])
            ftmp.append(node_info[4])
            btmp.append(node_info[5])

        # currently we view the loss backward temp as zero
        bwd_time.append(0)
        btmp.append(0)

        return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp)

    @classmethod
    def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
        """Extract node info from a list of nodes"""
        xbar = 0
        fwd_time = 0
        bwd_time = 0
        for n in node:
            assert isinstance(n, Node), f'{n} is not a Node'
            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
            # a minimum flop count is required
            fwd_time += max(calculate_fwd_time(n), 1.0)
            bwd_time += max(calculate_bwd_time(n), 1.0)

        x = calculate_fwd_out(node[-1])
        xbar = max(x, xbar)
        ftmp = cls._extract_ftmp(node)
        btmp = cls._extract_btmp(node)
        return fwd_time, bwd_time, x, xbar, ftmp, btmp

    @staticmethod
    def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
        """Extract input tensors from a Graph"""
        input_tensors = []
        for node in graph.nodes:
            if node.op == 'placeholder':
                input_tensors.append(node.meta['fwd_out'])
        return input_tensors

    @staticmethod
    def _extract_ftmp(node: List[Node]) -> int:
        """Extract ftmp from a list of nodes"""
        n = node[-1]
        return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)

    @staticmethod
    def _extract_btmp(node: List[Node]) -> int:
        """Extract btmp from a list of nodes"""

        def _extract_deps_size():
            deps_size = 0
            for k, v in deps.items():
                k: Node
                if v > 0:
                    deps_size += k.meta['bwd_mem_out']
                if v == float('-inf'):
                    deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)

            return deps_size

        btmp = 0
        deps = {}
        for n in reversed(node):
            deps[n] = len(n.all_input_nodes)
            btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
            for child in n.users:
                if child in deps:
                    deps[child] -= 1
                    if deps[child] <= 0:
                        deps[child] = float('-inf')    # free
        return btmp

    @staticmethod
    def _compute_table(chain: Chain, mem_slots: int) -> Tuple:
        """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            mem_slots (int): Number of slots for discretizing the memory budget.

        Returns:
            cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length
                and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
            back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice
                is a chain checkpoint, and (False, j) if the optimal choice is a leaf checkpoint
                of length j
        """
        ftime = chain.ftime + [0.0]
        btime = chain.btime
        x = chain.x + [0]
        xbar = chain.xbar + [0]
        ftmp = chain.ftmp + [0]
        btmp = chain.btmp + [0]

        # Build table
        cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
        back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
        # The last dimension is a dict because its indices go from i to l. Renumbering will wait for the C implementation

        # Initialize the borders of the tables for lmax - lmin = 0
        for m in range(mem_slots + 1):
            for i in range(chain.length + 1):
                limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
                if m >= limit:    # Equation (1)
                    cost_table[m][i][i] = ftime[i] + btime[i]
                else:
                    cost_table[m][i][i] = float("inf")

        # Compute everything
        for m in range(mem_slots + 1):
            for d in range(1, chain.length + 1):
                for i in range(chain.length + 1 - d):
                    idx = i + d
                    mmin = x[idx + 1] + x[i + 1] + ftmp[i]
                    if idx > i + 1:
                        mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
                    if m < mmin:
                        cost_table[m][i][idx] = float("inf")
                    else:
                        leaf_checkpoints = [(j,
                                             sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
                                            for j in range(i + 1, idx + 1)
                                            if m >= x[j]]
                        if leaf_checkpoints:
                            best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                        else:
                            best_leaf = None
                        if m >= xbar[i + 1]:
                            chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
                        else:
                            chain_checkpoint = float("inf")
                        if best_leaf and best_leaf[1] <= chain_checkpoint:
                            cost_table[m][i][idx] = best_leaf[1]
                            back_ptr[m][i][idx] = (False, best_leaf[0])
                        else:
                            cost_table[m][i][idx] = chain_checkpoint
                            back_ptr[m][i][idx] = (True,)
        return cost_table, back_ptr

    @staticmethod
    def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple:
        raise NotImplementedError("C implementation not available yet")

    def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]],
                   back_ptr: List[List[Dict[int, int]]]) -> Sequence:
        """Backtrack the cost table and retrieve the optimal checkpointing strategy.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            lmin (int): The left index of the interval to backtrack.
            lmax (int): The right index of the interval to backtrack.
            mem_budget (int): The memory budget for processing this interval.
            cost_table (List[List[Dict[int, Tuple]]]): See `_compute_table()` for definitions.
            back_ptr (List[List[Dict[int, Tuple]]]): See `_compute_table()` for definitions.

        Raises:
            ValueError: Cannot process the chain.

        Returns:
            sequence (Sequence): The sequence of executing nodes with checkpoints.
        """
        if mem_budget <= 0:
            raise ValueError(f"Cannot process a chain with negative memory {mem_budget}")
        elif cost_table[mem_budget][lmin][lmax] == float("inf"):
            raise ValueError(f"Cannot process this chain from index {lmin} to {lmax} with memory {mem_budget}")

        sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget))
        if lmin == lmax:
            if lmin == chain.length:
                sequence.insert(Loss())
            else:
                sequence.insert(ForwardEnable(lmin))
                sequence.insert(Backward(lmin))
            return sequence

        if back_ptr[mem_budget][lmin][lmax][0]:
            sequence.insert(ForwardEnable(lmin))
            sequence.insert_sequence(
                self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr))
            sequence.insert(Backward(lmin))
        else:
            j = back_ptr[mem_budget][lmin][lmax][1]
            sequence.insert(ForwardCheck(lmin))
            for k in range(lmin + 1, j):
                sequence.insert(ForwardNograd(k))
            sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr))
            sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr))
        return sequence

    @staticmethod
    def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
        op_list = sequence.list_operations()
        loss_op = next(op for op in op_list if isinstance(op, Loss))
        fwd_list = op_list[:op_list.index(loss_op)]
        bwd_list = op_list[op_list.index(loss_op) + 1:]
        ckpt_idx = 0
        in_ckpt = False
        ckpt_region = []

        # forward annotation
        for idx, op in enumerate(fwd_list, 0):
            if in_ckpt:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(idx)

                elif isinstance(op, ForwardEnable):
                    in_ckpt = False
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = [idx]

            else:
                if isinstance(op, ForwardCheck):
                    in_ckpt = True
                    ckpt_region.append(idx)

        # annotate the backward if there is any nested activation checkpoint
        in_recompute = False
        for op in bwd_list:
            if in_recompute:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(op.index)

                elif isinstance(op, ForwardEnable):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = [op.index]

                elif isinstance(op, Backward):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    in_recompute = False

            else:
                if not isinstance(op, Backward):
                    in_recompute = True
                    ckpt_idx = 0
                    ckpt_region = []
                    if isinstance(op, ForwardCheck):
                        ckpt_region.append(op.index)

        # postprocess: make sure every activation checkpoint label in the
        # same activation checkpoint region (level = 0) has the same length
        op_list = []
        for node in node_list:
            op_list += node
        ckpt_regions = _find_nested_ckpt_regions(op_list)
        for (start_idx, end_idx) in ckpt_regions:
            nested_length = max(
                len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
            for idx in range(start_idx, end_idx + 1):
                op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
                                                                        len(op_list[idx].meta['activation_checkpoint']))
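
Note: the recurrence that `_compute_table` fills in, transcribed from the code above into the paper's notation (the index conventions here are a sketch of the implementation, not the paper's exact statement):

    C_m(i,i) = \begin{cases}
        f_i + b_i & \text{if } m \ge x_{i+1} + \bar{x}_{i+1} + \max(f^{\mathrm{tmp}}_i, b^{\mathrm{tmp}}_i) \\
        \infty & \text{otherwise}
    \end{cases}

    C_m(i,l) = \min\Big(
        \underbrace{\min_{i < j \le l} \sum_{k=i}^{j-1} f_k + C_{m - x_j}(j,l) + C_m(i,j-1)}_{\text{leaf checkpoint at } j},\;
        \underbrace{C_m(i,i) + C_{m-\bar{x}_{i+1}}(i+1,l)}_{\text{chain checkpoint}}
    \Big)

`back_ptr` records which branch won: `(False, j)` for a leaf checkpoint of length j, `(True,)` for a chain checkpoint, which is exactly what `_backtrack` consumes.
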
@@ -0,0 +1,241 @@
import math
from abc import ABC
from typing import List

from torch.utils._pytree import tree_map


class Chain:

    def __init__(self,
                 ftime: List[float],
                 btime: List[float],
                 x: List[int],
                 xbar: List[int],
                 ftmp: List[int],
                 btmp: List[int],
                 check_consistency: bool = True):
        """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpointing.
        See the paper https://hal.inria.fr/hal-02352969 for details.

        Args:
            ftime (List[float]): The forward time of each node.
            btime (List[float]): The backward time of each node.
            x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
            xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
            ftmp (List[int]): The temporary forward memory of each node.
            btmp (List[int]): The temporary backward memory of each node, can be used to control the memory budget.
            check_consistency (bool, optional): Check the length consistency of the `Chain`. Defaults to True.
        """
        self.ftime = ftime
        self.btime = btime
        self.x = x
        self.xbar = xbar
        self.ftmp = ftmp
        self.btmp = btmp
        self.length = len(ftime)
        if check_consistency and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        return ((len(self.ftime) == self.length) and (len(self.btime) == self.length + 1)
                and (len(self.x) == self.length + 1) and (len(self.ftmp) == self.length)
                and (len(self.btmp) == self.length + 1) and (len(self.xbar) == self.length + 1))

    def __repr__(self):
        chain_list = []
        for i in range(self.length):
            chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
        i = self.length
        chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
        return chain_list.__repr__()

    def discretize_all(self, unit: int):
        """Discretize all the memory costs of the chain by the given unit size."""
        discretizer = lambda val: math.ceil(val / unit)
        self.x = tree_map(discretizer, self.x)
        self.xbar = tree_map(discretizer, self.xbar)
        self.ftmp = tree_map(discretizer, self.ftmp)
        self.btmp = tree_map(discretizer, self.btmp)


class Operation(ABC):
    name = "Op"

    def __repr__(self) -> str:
        return f"{self.name}_{self.index}"

    def shift(self, value):
        if type(self.index) is tuple:
            self.index = tuple(x + value for x in self.index)
        else:
            self.index += value


class Forward(Operation):
    name = "F"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        if chain is not None:
            return chain.ftime[self.index]
        else:
            return 1


class ForwardEnable(Forward):
    name = "Fe"


class ForwardNograd(Forward):
    name = "Fn"


class ForwardCheck(Forward):
    name = "CF"


class Forwards(Operation):

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])

    def cost(self, chain: Chain):
        if chain is not None:
            return sum(chain.ftime[self.index[0]:self.index[1] + 1])
        else:
            return (self.index[1] - self.index[0] + 1)


def isForward(op):
    return type(op) is Forward or type(op) is Forwards


class Backward(Operation):
    name = "B"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        if chain is not None:
            return chain.btime[self.index]
        else:
            return 1


class Loss(Operation):

    def __init__(self):
        pass

    def __repr__(self):
        return "L"

    def cost(self, chain):
        return 0


class MemoryAccess(Operation):
    name = "MA"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        return 0


class WriteMemory(MemoryAccess):
    name = "WM"


class ReadMemory(MemoryAccess):
    name = "RM"


class DiscardMemory(MemoryAccess):
    name = "DM"


class Function:

    def __init__(self, name, *args):
        self.name = name
        self.args = args
        self.str_args = ','.join(str(v) for v in self.args)

    def __repr__(self):
        return "{n}({args})".format(n=self.name, args=self.str_args)


class Sequence:

    def __init__(self, function):
        self.sequence = []    # list of Operation and Sequence
        self.function = function    # describes the function (name and parameters)

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        op_list = []
        for x in self.sequence:
            if isinstance(x, Operation):
                op_list.append(x)
            else:
                assert isinstance(x, Sequence)
                op_list += x.list_operations()
        return op_list

    def insert(self, operation):
        self.sequence.append(operation)

    def remove(self, operation_index):
        del self.sequence[operation_index]

    def insert_sequence(self, sequence):
        self.sequence.append(sequence)

    def shift(self, value):
        for x in self.sequence:
            x.shift(value)
        return self

    def remove_useless_write(self):
        if self.sequence:
            if isinstance(self.sequence[0], WriteMemory):
                self.remove(0)
        return self

    def get_makespan(self, chain):
        return sum(op.cost(chain) for op in self.list_operations())

    def without_suffix(self):
        ops = self.list_operations()
        end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
        try:
            last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
        except ValueError:
            last_idx = -1
        if last_idx == end_of_first_phase - 1:
            return (self, None)
        chain_length = ops[end_of_first_phase -
                           1].index    # Some assumption here about the sequence (finishes with Forward_L
        start_of_fwd_enable_chain = ops[last_idx + 1].index    # and starts with B_L), but should be fine in practice
        result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
        for i in range(last_idx + 1):
            result.insert(ops[i])
        result.insert(Loss())
        for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
            position = end_of_first_phase + 1 + (chain_length - i)
            assert type(ops[position]) is Backward
            assert ops[position].index == i
        for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
            result.insert(ops[i])
        return (result, start_of_fwd_enable_chain)
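
Note: a quick illustration of the discretization arithmetic used by the rotor solver above (a standalone sketch; the budget and sizes are made up). `unit = memory_budget // memory_slots`, and every size is rounded up to whole units, so the DP only ever indexes by slot count:

    import math

    memory_budget = 2 * 1024**3                   # 2 GiB, hypothetical
    memory_slots = 500
    unit = memory_budget // memory_slots          # 4294967 bytes per slot

    sizes = [3_000_000, 12_500_000, 900_000]      # activation sizes in bytes, made up
    slots = [math.ceil(s / unit) for s in sizes]  # same rule as Chain.discretize_all
    print(unit, slots)                            # 4294967 [1, 3, 1]
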
@@ -1,14 +1,37 @@
-import colossalai
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
 import torch
-from typing import List, Callable, Any, Tuple, Dict, Iterable
+
+import colossalai

 try:
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    from torch.fx.graph import (
+        CodeGen,
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        inplace_methods,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
     CODEGEN_AVAILABLE = True
 except:
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    from torch.fx.graph import (
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_args,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
     CODEGEN_AVAILABLE = False

 if CODEGEN_AVAILABLE:
@@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks():
             return (x.device, x.cpu())
         else:
             return x


    def pack_hook_no_input(self, x):
        if getattr(x, "offload", True):
            return (x.device, x.cpu())
@@ -48,11 +71,9 @@ def pack_hook_no_input(self, x):

 def _gen_save_tensors_hooks_context(offload_input=True) -> str:
     """Generate customized saved_tensors_hooks

     Args:
         offload_input (bool, optional): whether we need to offload the input; if offload_input=False,
             we will use self.pack_hook_no_input instead. Defaults to True.

     Returns:
         str: generated context
     """
@@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]):
     current_region = None

     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
+        if 'activation_checkpoint' in node.meta:
+            act_ckpt_label = node.meta['activation_checkpoint']

             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
@@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]):
                 current_region = act_ckpt_label
                 start = idx
                 end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
+        elif current_region is not None and not 'activation_checkpoint' in node.meta:
             # used to check the case below
             # node ckpt states = [ckpt, ckpt, non-ckpt]
             end = idx - 1
@@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]):

 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions.
     In the pofo algorithm, during annotation, we will annotate the offload region with a
     list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
     region's index, offload_input is a bool indicating whether we need to offload
     the input, offload_bar is a bool indicating whether we need to offload all the
@@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]):
     current_region = None

     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
-            act_offload_label = node.activation_offload
+        if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
+            act_offload_label = node.meta['activation_offload']

             if current_region == None:
                 current_region = act_offload_label
@@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen

 def _end_of_ckpt(node: Node, check_idx: int) -> bool:
     """Check if the node could end the ckpt region

     Args:
         node (Node): torch.fx.Node
         check_idx (int): the index of the checkpoint level for
             nested checkpoint

     Returns:
         bool
     """
-    if hasattr(node, "activation_checkpoint"):
-        if isinstance(node.activation_checkpoint, list):
-            return node.activation_checkpoint[check_idx] == None
+    if 'activation_checkpoint' in node.meta:
+        if isinstance(node.meta['activation_checkpoint'], list):
+            return node.meta['activation_checkpoint'][check_idx] == None
         else:
             return False
     else:
@@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:

 def _find_nested_ckpt_regions(nodes, check_idx=0):
     """
     Find the nested checkpoint regions given a list of consecutive nodes. The output
     will be a list of tuples, each tuple in the form of (start_index, end_index).
     """
     ckpt_regions = []
@@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
     current_region = None

     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            if isinstance(getattr(node, 'activation_checkpoint'), int):
-                act_ckpt_label = node.activation_checkpoint
+        if 'activation_checkpoint' in node.meta:
+            if isinstance(node.meta['activation_checkpoint'], int):
+                act_ckpt_label = node.meta['activation_checkpoint']
             else:
-                act_ckpt_label = node.activation_checkpoint[check_idx]
+                act_ckpt_label = node.meta['activation_checkpoint'][check_idx]

             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
@@ -287,7 +306,6 @@ def emit_ckpt_func(body,
                    level=0,
                    in_ckpt=False):
     """Emit ckpt function in a nested way

     Args:
         body: forward code, in recursive calls, this part will be the checkpoint
             functions code
@@ -303,8 +321,8 @@ def emit_ckpt_func(body,
     inputs, outputs = _find_input_and_output_nodes(node_list)

     # if the current checkpoint function uses an int as its label, use the old generation method
-    if isinstance(node_list[0].activation_checkpoint, int):
-        label = node_list[0].activation_checkpoint
+    if isinstance(node_list[0].meta['activation_checkpoint'], int):
+        label = node_list[0].meta['activation_checkpoint']
         ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
         ckpt_func.append(f'{ckpt_fn_def}\n')
         for node in node_list:
@@ -313,7 +331,7 @@ def emit_ckpt_func(body,
             delete_unused_value_func(node, ckpt_func)

         ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-        activation_offload = getattr(node_list[0], "activation_offload", False)
+        activation_offload = node_list[0].meta.get('activation_offload', False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
         usage += "\n"
         body.append(usage)
@@ -322,12 +340,12 @@ def emit_ckpt_func(body,
     else:
         # label given by each layer, e.g. if you are currently at level [0, 1, 1]
         # the label will be '0_1_1'
-        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
+        label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
         ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
         ckpt_func.append(f'{ckpt_fn_def}\n')

         # if there are more levels to fetch
-        if level + 1 < len(node_list[0].activation_checkpoint):
+        if level + 1 < len(node_list[0].meta['activation_checkpoint']):
             ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
             start_idx = [item[0] for item in ckpt_regions]
             end_idx = [item[1] for item in ckpt_regions]
@@ -354,7 +372,7 @@ def emit_ckpt_func(body,

             ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
             ckpt_func += ckpt_func_buffer
-            activation_offload = getattr(node_list[0], "activation_offload", False)
+            activation_offload = node_list[0].meta.get('activation_offload', False)
             usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
             if in_ckpt:
                 usage = '    ' + usage
@@ -368,7 +386,7 @@ def emit_ckpt_func(body,
                 delete_unused_value_func(node, ckpt_func)

             ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-            activation_offload = getattr(node_list[0], "activation_offload", False)
+            activation_offload = node_list[0].meta.get('activation_offload', False)
             usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
             if in_ckpt:
                 usage = '    ' + usage
@@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
     """Emit code with nested activation checkpoint
     When we detect that some node's activation_checkpoint is a List, we will use
     this function to emit the activation checkpoint codes.

     Args:
         body: forward code
         ckpt_func: checkpoint functions code
@@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,

     # we need to check if the checkpoint needs to offload the input
     start_node_idx = start_idx[label]
-    if hasattr(node_list[start_node_idx], 'activation_offload'):
-        activation_offload = node_list[start_node_idx].activation_offload
+    if 'activation_offload' in node_list[start_node_idx].meta:
+        activation_offload = node_list[start_node_idx].meta['activation_offload']
     else:
         activation_offload = False

@@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     if input_node.op != "placeholder":
         non_leaf_input = 1
         for user in input_node.users:
-            if hasattr(user, "activation_checkpoint"):
-                if user.activation_checkpoint == label:
+            if 'activation_checkpoint' in user.meta:
+                if user.meta['activation_checkpoint'] == label:
                     if user.op == "call_module":
                         if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
                             use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
@@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE:

     def add_global(name_hint: str, obj: Any):
         """Add an obj to be tracked as a global.

         We call this for names that reference objects external to the
         Graph, like functions or types.

         Returns: the global name that should be used to reference 'obj' in generated source.
         """
         if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
@@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE:

         # if any node has a list of labels for activation_checkpoint, we
         # will use the nested type of activation checkpoint codegen
-        if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
+        if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
             emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
         else:
             emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE:
         code = '\n'.join('    ' + line for line in code.split('\n'))
         fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
         return PythonCode(fn_code, globals_)
@@ -851,10 +865,8 @@ else:

     def add_global(name_hint: str, obj: Any):
         """Add an obj to be tracked as a global.

         We call this for names that reference objects external to the
         Graph, like functions or types.

         Returns: the global name that should be used to reference 'obj' in generated source.
         """
         if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
@@ -999,7 +1011,7 @@ else:

         # if any node has a list of labels for activation_checkpoint, we
         # will use the nested type of activation checkpoint codegen
-        if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
+        if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
             emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
         else:
             emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@@ -1040,7 +1052,6 @@ else:
         # in forward function
         fn_code = f"""
{wrap_stmts}

{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
{code}"""
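
Note: the recurring change in this file (and in the tests below) is a convention switch: checkpoint and offload annotations move from ad-hoc attributes on `torch.fx.Node` to the `node.meta` dict that fx already owns. A minimal sketch of the two styles side by side (illustrative only):

    # old convention: labels hang off the Node object itself
    setattr(node, 'activation_checkpoint', [0, 1])
    label = getattr(node, 'activation_checkpoint', None)

    # new convention: labels live in the node.meta dict
    node.meta['activation_checkpoint'] = [0, 1]
    label = node.meta.get('activation_checkpoint', None)
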
@@ -13,10 +13,10 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.

     Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
+        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`.

     Returns:
-        int: The activation size
+        int: The activation size, unit is byte.
     """
     act_size = 0
     if isinstance(out, torch.Tensor):
@@ -38,10 +38,10 @@ def parameter_size(mod: torch.nn.Module) -> int:
     """Calculate parameter size of a node.

     Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
+        mod (torch.nn.Module): The target `torch.nn.Module`.

     Returns:
-        int: The parameter size
+        int: The parameter size, unit is byte.
     """
     param_size = 0
     for param in mod.parameters():
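
Note: both helpers report bytes; for a single tensor the arithmetic is element count times element size. A quick standalone illustration (not calling the library itself):

    import torch

    t = torch.empty(4, 128, dtype=torch.float16)
    size_in_bytes = t.numel() * t.element_size()    # 4 * 128 * 2 = 1024 bytes
    print(size_in_bytes)

    # parameter_size(mod) is the same idea summed over mod.parameters()
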
@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G

     def pack(x):
         global cache, do_not_cache
-        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             x._node.meta['saved_tensor'] += [tensor]
             if not do_not_cache:
-                cache.add(x._tensor.uuid)
+                cache.add(x._tensor.data_ptr())
         return x

     def unpack(x):
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def extract_tensor(x: Any):
         if isinstance(x, MetaTensor):
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             return tensor
         if not isinstance(x, torch.finfo):
             return x
@@ -87,8 +87,8 @@ def calculate_fwd_out(n: Node) -> int:

     fwd_in = dict()
     for u in n.users:
-        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
-    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+        fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
+    fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
     return activation_size(intersect(fwd_in, fwd_out))
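
Note: the idea behind `calculate_fwd_out` is that a node's forward output only counts if some consumer actually reads it, so both sides are keyed by `data_ptr()` and intersected. A toy sketch of the intersection step (assuming `intersect` is a plain dict intersection, which is an assumption about that helper):

    # hypothetical sketch: key tensors by identity, keep only the shared ones
    fwd_in = {101: 'a', 102: 'b'}     # data_ptr -> tensor, collected from the users
    fwd_out = {102: 'b', 103: 'c'}    # data_ptr -> tensor, produced by the node
    shared = {k: v for k, v in fwd_out.items() if k in fwd_in}
    print(shared)                     # {102: 'b'} -> only this counts toward fwd_out
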
@@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN
 __all__ = ['MetaTensor']


-def set_uuid(x):
+def set_data_ptr(x):
     if isinstance(x, torch.Tensor):
-        if not hasattr(x, 'uuid'):
-            setattr(x, 'uuid', uuid.uuid4())
+        if not x.data_ptr():
+            data_ptr = uuid.uuid4()
+            x.data_ptr = lambda: data_ptr


 @compatibility(is_backward_compatible=False)
@@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only a tensor not on `meta` should be copied to `meta`
-        set_uuid(r._tensor)
+        set_data_ptr(r._tensor)
         return r

     def __repr__(self):
@@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor):
         # here we keep the data_ptr of the input because ALIAS_ATEN does not generate a physical copy
         # of the input
         if func in ALIAS_ATEN:
-            setattr(out, 'uuid', args[0].uuid)
+            out.data_ptr = args[0].data_ptr

         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
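
Note: why this refactor works. Meta tensors have no storage, so `Tensor.data_ptr()` reports 0 in the PyTorch versions this code targets; patching a per-tensor closure gives every traced tensor a stable fake address, and aliasing ops simply share the closure. A hedged standalone sketch of the same trick:

    import uuid
    import torch

    x = torch.empty(4, 4, device='meta')
    assert x.data_ptr() == 0       # a meta tensor reports a null pointer

    fake = uuid.uuid4()
    x.data_ptr = lambda: fake      # instance attribute shadows the C method

    y = x.view(16)                 # stand-in for an ALIAS_ATEN op on x
    y.data_ptr = x.data_ptr        # aliases share the same fake address
    assert x.data_ptr() == y.data_ptr()
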
@@ -1,26 +1,28 @@
 #!/usr/bin/env python
 """
 tracer.py:
     Implemented a tracer which supports control flow and user-defined meta arguments.
     The implementation is partly inspired by HuggingFace's fx tracer.
 """
 import enum
-import inspect
 import functools
+import inspect
 import operator
 from contextlib import contextmanager
-from colossalai.fx.tracer.meta_patch import meta_patched_module
+from typing import Any, Dict, Optional

 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.fx import Tracer, Node
-from torch.fx.graph import Graph
-from torch.fx.proxy import Proxy, ParameterProxy
+from torch.fx import Node, Tracer
+from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
+from torch.fx.proxy import ParameterProxy, Proxy
+
+from colossalai.fx.tracer.meta_patch import meta_patched_module

 from ..proxy import ColoProxy
-from typing import Optional, Dict, Any
-from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
+from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
 from .meta_patch import meta_patched_function, meta_patched_module
-from torch.fx.graph import magic_methods, reflectable_magic_methods

 __all__ = ['ColoTracer']
@@ -231,7 +233,7 @@ class ColoTracer(Tracer):

         Args:
             root (nn.Module): a `nn.Module` object to trace the computation graph
             meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
                 These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
             concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
         """
@@ -383,7 +385,7 @@ class ColoTracer(Tracer):

         if self.inside_torch_checkpoint_func:
             # annotate the activation checkpoint module
-            setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
+            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
         return node
@@ -2,11 +2,13 @@ import copy
 import re
 from typing import Callable

-import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
+from torch.fx import GraphModule

+import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
 from colossalai.fx._compatibility import is_compatible_with_meta
@@ -14,7 +16,6 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port
-from torch.fx import GraphModule

 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
@@ -94,6 +95,7 @@ def _run_ckpt_solver(rank):
     gpc.destroy()


+@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint

 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta

         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta

         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint

 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1
     gm = ColoGraphModule(model, graph)
     gm.recompile()

@@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1
     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -1,14 +1,16 @@
 import copy
-import torch
-import torch.nn.functional as F

 import pytest
+import torch
 import torch.multiprocessing as mp
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer

 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
@@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 from torch.utils.checkpoint import checkpoint

+from colossalai.fx import ColoTracer
+

 class MLP(torch.nn.Module):

@@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation():

     for node in gm.graph.nodes:
         if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 0
+            assert node.meta.get('activation_checkpoint', -1) == 0

     for node in gm.graph.nodes:
         if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 1
+            assert node.meta.get('activation_checkpoint', -1) == 1

     tracer = ColoTracer(trace_act_ckpt=False)
     graph = tracer.trace(module)