[hotfix/rotor] fix variable names (#1597)

* [fx] add some comments and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add interpretation for graph analysis.

* [fx] before doing saved_tensors_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.

* [fx] fix variable names in ckpt_solvers and unskip tests.

* [fx] commit my changes.

* [fx] restore skips.

* [fx] restore skips.

* [fx] change stage into phase.

* [fx] change stage into phase.

* [fx] change stage into phase.
Super Daniel 2022-09-14 14:27:04 +08:00 committed by GitHub
parent faa23b9d9a
commit c8e9b2ad78
7 changed files with 85 additions and 83 deletions

View File

@@ -73,10 +73,11 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     y = 0
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
-        temp += getattr(n, 'fwd_out')
+        n: Node
+        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += getattr(n, 'fwd_out')
+            x += n.meta['fwd_mem_out']
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
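The pass now reads per-node memory from `n.meta` instead of ad-hoc node attributes. A minimal runnable sketch of the same budget-driven interval split, using hypothetical nodes that carry the same `meta` keys:

from types import SimpleNamespace

def greedy_intervals(nodes, budget):
    # accumulate activation memory until the budget is exceeded, then
    # close a checkpoint interval, mirroring the loop in the diff above
    intervals, temp, prev_idx = [], 0, 0
    for idx, n in enumerate(nodes):
        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
        if temp > budget:
            temp = 0
            intervals.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    return intervals

nodes = [SimpleNamespace(meta={'fwd_mem_out': 4, 'fwd_mem_tmp': 1}) for _ in range(6)]
print(greedy_intervals(nodes, budget=8))    # [(0, 2), (2, 4), (4, 6)]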

View File

@ -1,11 +1,11 @@
from typing import List, Set, Tuple, Dict from typing import List, Tuple
import torch import torch
from torch.fx import GraphModule, Node from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
import math import math
from .linearize import linearize from .linearize import linearize
from .utils import * from .utils import *
from colossalai.fx.profiler import profile_function, profile_module
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
@@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     bw = chain.bweight    ## backward time, not used
     cw = chain.cweight + [0]    ## size of x (and of y)
     cbw = chain.cbweight + [0]    ## size of xbar
-    fwd_tmp = chain.fwd_tmp + [0]
-    bwd_tmp = chain.bwd_tmp + [0]
+    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
+    bwd_mem_tmp = chain.bwd_mem_tmp + [0]

     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
         for i in range(chain.length + 1):
             #lmax-lmin = 0
-            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+            limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
             if m >= limit:    ## Equation (1)
                 opt[m][i][i] = fw[i] + bw[i]
             else:
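For context, `limit` is the feasibility bound the inline comment calls Equation (1): stage i can be computed and immediately backpropagated only if memory holds the next input (`cw[i + 1]`), its xbar (`cbw[i + 1]`), and the larger of the forward and backward temporaries. A toy check with made-up sizes:

cw_next, cbw_next = 8, 16              # illustrative sizes only
fwd_mem_tmp_i, bwd_mem_tmp_i = 4, 6
limit = max(cw_next + cbw_next + fwd_mem_tmp_i,
            cw_next + cbw_next + bwd_mem_tmp_i)
assert limit == 30                     # the budget m must reach this to apply Equation (1)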
@@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
         for i in range(chain.length + 1 - d):
             # for idx in range(i+1, chain.length + 1):
             idx = i + d
-            mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
+            mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
             if idx > i + 1:
-                mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
+                mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
             if m < mmin:
                 opt[m][i][idx] = float("inf")
             else:
@@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.fwd_tmp + n.fwd_out
+        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
     return xbar
@@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
     fwd_time = 0
     for n in node:
         # minimum flop count is needed
-        fwd_time += max(n.fwd_flop, 1)
+        fwd_time += max(n.meta['fwd_flop'], 1)
     return fwd_time
@@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
     bwd_time = 0
     for n in node:
         # minimum flop count is needed
-        bwd_time += max(n.bwd_flop, 1)
+        bwd_time += max(n.meta['bwd_flop'], 1)
     return bwd_time


-def _get_bwd_tmp(node: List[Node]) -> int:
+def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node

     Args:
@@ -218,29 +218,32 @@ def _get_bwd_tmp(node: List[Node]) -> int:

     def _get_deps_size():
         deps_size = 0
-        for key in deps.keys():
-            deps_size += key.bwd_out
+        for k, v in deps.items():
+            if v > 0:
+                deps_size += k.meta['bwd_mem_out']
         return deps_size

-    bwd_tmp = 0
+    bwd_mem_tmp = 0
     deps = {}

     # add all the users for last node into deps,
     # as those nodes' gradient out will be stored in memory
-    for son in node[-1].users:
-        deps[son] = 1
+    for child in node[-1].users:
+        deps[child] = 1
     for n in reversed(node):
-        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
-        deps[n] = len(n._input_nodes)
-        for son in n.users:
-            deps[son] -= 1
+        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
+        deps[n] = len(n.all_input_nodes)
+        for child in n.users:
+            if child in deps:
+                deps[child] -= 1
         for key in list(deps.keys()):
             if deps[key] == 0:
                 del deps[key]

-    return bwd_tmp
+    return bwd_mem_tmp


 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
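Two behavioural fixes hide in this hunk besides the rename: `_get_deps_size` now skips entries whose refcount has already dropped to zero, and users outside the tracked region no longer break the decrement (`if child in deps`). A toy check of the first fix, with hypothetical names and sizes:

deps = {'grad_a': 1, 'grad_b': 0}             # hypothetical refcounts
bwd_mem_out = {'grad_a': 512, 'grad_b': 256}  # hypothetical sizes in bytes
live = sum(bwd_mem_out[k] for k, v in deps.items() if v > 0)
assert live == 512                            # grad_b is freed, so it no longer counts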
@@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
-        tmp_bwd.append(_get_bwd_tmp(node))
+        tmp_bwd.append(_get_bwd_mem_tmp(node))

         # if a node with only one inplace op, we need to let x_bar = 0
         if len(node) == 1 and _get_inplace(node[0]):
@@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
     """
     node_list = linearize(gm, cnode)
+    mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
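With the new `parameter_size` subtraction, the caller passes the total device budget as `mem_limit` and the solver plans activations against whatever is left after the weights. Illustrative arithmetic only; the figures are made up:

total_budget = 2 * 1024**3        # 2 GiB handed to solver_rotor as mem_limit
params = 400 * 1024**2            # suppose parameter_size(gm) reports 400 MiB
activation_budget = total_budget - params
assert activation_budget == 1648 * 1024**2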

View File

@@ -5,24 +5,24 @@ class Chain:
         self.bweight = bw
         self.cweight = cw
         self.cbweight = cbw
-        self.fwd_tmp = ftmp
-        self.bwd_tmp = btmp
+        self.fwd_mem_tmp = ftmp
+        self.bwd_mem_tmp = btmp
         self.length = len(fw)
         if check and not self.check_lengths():
             raise AttributeError("In Chain, input lists do not have consistent lengths")

     def check_lengths(self):
         return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
-                and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
-                and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
+                and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
+                and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))

     def __repr__(self):
         chain_list = []
         for i in range(self.length):
-            chain_list.append(
-                (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
+            chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
+                               self.bwd_mem_tmp[i]))
         i = self.length
-        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
+        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
         return chain_list.__repr__()
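A minimal consistent `Chain` under the rules `check_lengths` enforces (forward-side lists carry `length` entries, backward-side lists one extra trailing slot). The keyword names follow the assignments shown above; the values are placeholders:

chain = Chain(fw=[1, 1], bw=[2, 2, 2], cw=[8, 8, 8],
              cbw=[16, 16, 16], ftmp=[0, 0], btmp=[4, 4, 4])
print(chain)    # one (fw, bw, cw, cbw, fwd_mem_tmp, bwd_mem_tmp) tuple per stage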

View File

@@ -94,12 +94,11 @@ class MetaInfoProp(torch.fx.Interpreter):
         tensor_meta = tree_map(extract_tensor_meta, result)
         n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
+        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
         for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
+            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
         n.meta['type'] = type(result)

         # retain the autograd graph
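The switch from `+=` to `max` matters when one activation feeds several consumers: summing each consumer's `fwd_mem_in` would count a tensor that exists once in memory multiple times. Toy numbers:

consumers_fwd_mem_in = [1024, 1024]    # two users of the same 1 KiB activation
summed = sum(consumers_fwd_mem_in)     # 2048: the old code counted it twice
taken = max(consumers_fwd_mem_in)      # 1024: the tensor is stored only once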

View File

@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size


-class Stage(Enum):
+class Phase(Enum):
     FORWARD = 0
     LOSS = 1
     BACKWARD = 2
@@ -48,24 +49,9 @@ class GraphInfo:
     bwd_mem_out: int = 0


-def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
-
-
-def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
-
-
-def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
-
-
-def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+def is_phase(n: Node, phase: Phase) -> bool:
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == phase


 def is_saved(n: Node):
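One parameterized predicate replaces four near-identical helpers; call sites that need a one-argument function bind the phase via `functools.partial`. A sketch, assuming `Phase` and `is_phase` as defined in this file:

from functools import partial

is_loss = partial(is_phase, phase=Phase.LOSS)
# `any(map(is_loss, n.users))` then reads the same as with the old helper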
@@ -74,7 +60,7 @@ def is_saved(n: Node):

 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
-    Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+    Basically the input graph should have all nodes marked for keyword `phase`.
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ---->  p           o     <---- We need to keep track of grad out
@@ -91,18 +77,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                        l
     =============================================================================
     Args:
-        graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+        graph (Graph): The autograd graph with nodes marked for keyword `phase`.

     Returns:
         graph_info (GraphInfo): Meta information for the dataflow.
     """

     def _peak_memory(deps: Dict[Node, int]):
-        bwd_tmp = 0
+        peak_mem = 0
         for k, v in deps.items():
             if v > 0:
-                bwd_tmp += activation_size(k.meta['out'])
-        return bwd_tmp
+                peak_mem += activation_size(k.meta['out'])
+        return peak_mem

     # deps is used to track all the memory dependencies of the graph.
     deps = {}
@@ -110,19 +96,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(is_loss, n.users)):
+        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_in`.
-            # Any `fwd_in` should be kept in memory even this function
+            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+            # Any `fwd_mem_in` should be kept in memory even this function
             # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
-            # the node, `fwd_tmp` can be freed.
-            if is_placeholder(n):
+            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+            # the node, `fwd_mem_tmp` can be freed.
+            if is_phase(n, Phase.PLACEHOLDER):
                 graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_forward(n):
+            if is_phase(n, Phase.FORWARD):
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
-        elif is_backward(n):
+        elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)

View File

@@ -5,8 +5,8 @@ import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import autograd_graph_analysis, Stage
-from .memory import WEIRD_OPS
+from .dataflow import GraphInfo, autograd_graph_analysis, Phase
+from .memory import WEIRD_OPS, activation_size
 from .tensor import MetaTensor
 from .opcount import flop_mapping
@@ -41,14 +41,11 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
-        Stage.FORWARD: 0,
-        Stage.LOSS: 0,
-        Stage.BACKWARD: 0,
+        Phase.FORWARD: 0,
+        Phase.LOSS: 0,
+        Phase.BACKWARD: 0,
     }

-    # `stage` will mark the stage of autograd from outside scope.
-    stage = Stage.FORWARD
-
     # FlopTensor not only get the flop statistics of a single node,
     # it also build a full autograd graph for this node.
     # This makes sure we can analyze the dependencies of memory, and
@@ -85,9 +82,9 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
         # run aten for backend=CPU but actually on backend=Meta
         out = func(*args, **kwargs)
-        flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
+        flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
         node.meta['out'] = normalize_tuple(out)
-        node.meta['stage'] = stage
+        node.meta['phase'] = phase

     def wrap(x):
         return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
@@ -121,7 +118,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
         x._node = subgraph.create_node('placeholder',
                                        'placeholder', (subgraph._root,),
                                        name=subgraph._graph_namespace.create_name('input', x._tensor))
-        x._node.meta['stage'] = Stage.PLACEHOLDER
+        x._node.meta['phase'] = Phase.PLACEHOLDER
         x._node.meta['out'] = (x._tensor,)

     tree_map(set_placeholder, args)
@@ -135,6 +132,8 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     def unpack(x):
         return x

+    # `phase` will mark the phase of autograd from outside scope.
+    phase = Phase.FORWARD
     # mark saved tensors with saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
         if isinstance(target, str):
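Moving the `phase = Phase.FORWARD` initialisation next to the `saved_tensors_hooks` block keeps the whole phase lifecycle in one place; `__torch_dispatch__` above reads `phase` from this closure. A stripped-down, self-contained view of the control flow (the `Phase` class here is a local stand-in mirroring the profiler's enum):

import torch
from enum import Enum

class Phase(Enum):    # local stand-in, not the profiler's own class
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2

phase = Phase.FORWARD
x = torch.randn(2, 2, requires_grad=True)
out = x * 2               # dispatched ops would be counted under FORWARD
phase = Phase.LOSS
loss = out.sum()          # the reduction is attributed to LOSS
phase = Phase.BACKWARD
loss.backward()           # gradient ops would be counted under BACKWARD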
@@ -147,13 +146,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
     if is_autogradable(out) and out.requires_grad:
-        stage = Stage.LOSS
+        phase = Phase.LOSS
         loss = out.sum()
-        stage = Stage.BACKWARD
+        phase = Phase.BACKWARD
         loss.backward()

     graph_info = autograd_graph_analysis(subgraph)
-    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
+    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

     def unwrap(x):
         return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
@@ -180,6 +179,11 @@ def profile_function(target: 'Target') -> Callable:
         # If there is an argument that this `call_function` is inplace, we should
         # skip the autograd profiling.
+        if kwargs.get('inplace', False):
+            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
+            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta
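The shortcut exists because autograd cannot differentiate through an op that mutated its input, so routing e.g. `F.relu(x, inplace=True)` through `_profile` would fail; the new branch just executes on meta tensors and returns a conservative estimate. A sanity check that the meta path itself runs, assuming a reasonably recent PyTorch:

import torch
import torch.nn.functional as F

x = torch.empty(4, 4, device='meta')
y = F.relu(x, inplace=True)    # shape-only execution, no autograd involved
print(y.shape)                 # torch.Size([4, 4])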
@@ -222,6 +226,11 @@ def profile_module(module: torch.nn.Module) -> Callable:
         # If there is an argument that this `call_module` is inplace, we should
         # skip the autograd profiling.
+        if getattr(module, 'inplace', False):
+            args = tree_map(lambda x: x.to('meta'), args)
+            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
         return out, meta
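The module path keys off the conventional `inplace` attribute that modules like `nn.ReLU` and `nn.Dropout` expose; modules without it fall through to full profiling. Quick check:

import torch.nn as nn

print(getattr(nn.ReLU(inplace=True), 'inplace', False))    # True  -> shortcut
print(getattr(nn.Linear(4, 4), 'inplace', False))          # False -> full profile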

View File

@@ -38,7 +38,8 @@ def test_linearize():
             if isinstance(op, ForwardNograd):
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                    assert n.activation_checkpoint[
+                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"

                 continue
@@ -54,7 +55,8 @@ def test_linearize():
                 ckpt_idx += 1
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                    assert n.activation_checkpoint[
+                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"

                 continue
@@ -63,7 +65,8 @@ def test_linearize():
                 in_ckpt = True
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                    assert n.activation_checkpoint[
+                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"

     del model
     del gm
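The assertions now index `activation_checkpoint[0]` because the annotation became a sequence, presumably so nested checkpoint regions can stack further indices (note the `_find_nested_ckpt_regions` import above). A toy version of the comparison:

activation_checkpoint = (3,)    # hypothetical annotation; index 0 is the region id
ckpt_idx = 3
assert activation_checkpoint[0] == ckpt_idx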