mirror of https://github.com/hpcaitech/ColossalAI

[hotfix/rotor] fix variable names (#1597)

* [fx] add some comments and docstrings.
* [fx] add dataflow analysis for an autograd graph.
* add interpretation for graph analysis.
* [fx] before doing save_tensor_hooks.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] a very accurate version on GPT-2.
* [fx] refactor code.
* [fx] remove redundant inplace=True.
* [fx] refactor code.
* [fx] refactor code.
* [fx] refactor code.
* [fx] dive into backward memory.
* [fx] fix variable names in ckpt_solvers and unskip tests.
* [fx] commit my changes.
* [fx] restore skips.
* [fx] restore skips.
* [fx] change stage into phase.
* [fx] change stage into phase.
* [fx] change stage into phase.

pull/1604/head

parent faa23b9d9a
commit c8e9b2ad78
@@ -73,10 +73,11 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     y = 0
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
-        temp += getattr(n, 'fwd_out')
+        n: Node
+        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += getattr(n, 'fwd_out')
+            x += n.meta['fwd_mem_out']
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
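Note: `chen_greedy` implements the O(sqrt(n)) heuristic from Chen et al., "Training Deep Nets with Sublinear Memory Cost": walk the linearized graph, accumulate per-node activation memory in `temp`, and cut a new checkpoint interval whenever the running total exceeds the budget `b`. A simplified standalone sketch of the interval cut (it omits the `ckpt_nodes` filter; the hypothetical `mem[i]` stands in for `n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']`):

    def greedy_intervals(mem, b):
        # mem: per-node activation bytes; b: per-segment budget (both assumed)
        intervals, temp, prev_idx = [], 0, 0
        for idx, m in enumerate(mem):
            temp += m                          # running activation total
            if temp > b:                       # budget exceeded: cut an interval
                intervals.append((prev_idx, idx + 1))
                prev_idx, temp = idx + 1, 0
        return intervals

    print(greedy_intervals([4, 4, 4, 4, 4], b=7))    # [(0, 2), (2, 4)]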
@@ -1,11 +1,11 @@
-from typing import List, Set, Tuple, Dict
+from typing import List, Tuple
 import torch
 from torch.fx import GraphModule, Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import parameter_size
 import math
 from .linearize import linearize
 from .utils import *
 from colossalai.fx.profiler import profile_function, profile_module
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
@@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     bw = chain.bweight    ## backward time, not used
     cw = chain.cweight + [0]    ## size of x (and of y)
     cbw = chain.cbweight + [0]    ## size of xbar
-    fwd_tmp = chain.fwd_tmp + [0]
-    bwd_tmp = chain.bwd_tmp + [0]
+    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
+    bwd_mem_tmp = chain.bwd_mem_tmp + [0]
 
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
         for i in range(chain.length + 1):
             #lmax-lmin = 0
-            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+            limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
             if m >= limit:    ## Equation (1)
                 opt[m][i][i] = fw[i] + bw[i]
             else:
@@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
         for i in range(chain.length + 1 - d):
             # for idx in range(i+1, chain.length + 1):
             idx = i + d
-            mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
+            mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
             if idx > i + 1:
-                mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
+                mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
             if m < mmin:
                 opt[m][i][idx] = float("inf")
             else:
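Note: `_compute_table` fills the dynamic-programming table of the rotor algorithm: `opt[m][i][j]` holds the optimal time to process segment i..j of the chain under memory budget m. The hunks above only rename `fwd_tmp`/`bwd_tmp` to `fwd_mem_tmp`/`bwd_mem_tmp`; the feasibility test of Equation (1) is unchanged. A self-contained sketch of that base-case check, with toy arrays in place of the Chain fields:

    def base_case_feasible(m, i, cw, cbw, fwd_mem_tmp, bwd_mem_tmp):
        # Segment [i, i] fits iff both its forward peak and backward peak fit in m.
        limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i],
                    cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
        return m >= limit

    cw, cbw = [4, 4, 4], [8, 8, 8]    # toy activation / xbar sizes
    print(base_case_feasible(18, 0, cw, cbw, [2, 2], [6, 6]))    # True: max(14, 18) <= 18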
@@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:
 
     xbar = 0
     for n in node:
-        xbar += n.fwd_tmp + n.fwd_out
+        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
     return xbar
 
 
@@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
     fwd_time = 0
     for n in node:
         # minimum flop count is needed
-        fwd_time += max(n.fwd_flop, 1)
+        fwd_time += max(n.meta['fwd_flop'], 1)
     return fwd_time
 
 
@@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
     bwd_time = 0
     for n in node:
         # minimum flop count is needed
-        bwd_time += max(n.bwd_flop, 1)
+        bwd_time += max(n.meta['bwd_flop'], 1)
     return bwd_time
 
 
-def _get_bwd_tmp(node: List[Node]) -> int:
+def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node
 
     Args:
@@ -218,29 +218,32 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
 
     def _get_deps_size():
         deps_size = 0
-        for key in deps.keys():
-            deps_size += key.bwd_out
+        for k, v in deps.items():
+            if v > 0:
+                deps_size += k.meta['bwd_mem_out']
+
         return deps_size
 
-    bwd_tmp = 0
+    bwd_mem_tmp = 0
     deps = {}
 
     # add all the users for last node into deps,
     # as those nodes' gradient out will be stored in memory
-    for son in node[-1].users:
-        deps[son] = 1
+    for child in node[-1].users:
+        deps[child] = 1
     for n in reversed(node):
-        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
-        deps[n] = len(n._input_nodes)
-        for son in n.users:
-            deps[son] -= 1
+        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
+
+        deps[n] = len(n.all_input_nodes)
+        for child in n.users:
+            if child in deps:
+                deps[child] -= 1
 
         for key in list(deps.keys()):
            if deps[key] == 0:
                del deps[key]
 
-    return bwd_tmp
+    return bwd_mem_tmp
 
 
 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
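Note: the rewritten `_get_bwd_mem_tmp` estimates the backward peak with reference counting: `deps` maps each node to its number of unresolved gradient dependencies, and `_get_deps_size` now only counts entries whose count is still positive, instead of summing every key ever inserted. A self-contained sketch of the same pattern, with plain dicts standing in for `Node` metadata (`inputs`, `users`, and `grad_bytes` are all hypothetical stand-ins):

    def peak_backward_memory(order, inputs, users, grad_bytes):
        deps = {c: 1 for c in users[order[-1]]}    # grads flowing out of the segment
        peak = 0
        for n in reversed(order):                  # walk in backward order
            live = sum(grad_bytes[k] for k, v in deps.items() if v > 0)
            peak = max(peak, live + grad_bytes[n])    # n's own temporary gradient
            deps[n] = len(inputs[n])               # n stays live until its inputs' grads exist
            for c in users[n]:
                if c in deps:
                    deps[c] -= 1
        return peak

    order = ['a', 'b', 'c']
    inputs = {'a': [], 'b': ['a'], 'c': ['b']}
    users = {'a': ['b'], 'b': ['c'], 'c': []}
    grad_bytes = {'a': 4, 'b': 4, 'c': 4}
    print(peak_backward_memory(order, inputs, users, grad_bytes))    # 8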
@@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
-        tmp_bwd.append(_get_bwd_tmp(node))
+        tmp_bwd.append(_get_bwd_mem_tmp(node))
 
         # if a node with only one inplace op, we need to let x_bar = 0
         if len(node) == 1 and _get_inplace(node[0]):
@@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
     """
 
     node_list = linearize(gm, cnode)
+    mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
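Note: the one functional addition in this hunk is `mem_limit -= parameter_size(gm)`: parameters stay resident for the whole run, so the solver should budget only the remainder for activations before discretizing it into `mem_slots` units. Illustration of the arithmetic with assumed values:

    param_bytes = 1_200_000_000          # stand-in for parameter_size(gm)
    mem_limit = 8 * 1024**3              # user-supplied budget: 8 GiB
    eps, mem_slots = 0.02, 500           # safety margin and DP discretization

    mem_limit -= param_bytes             # parameters are always resident
    mem_unit = mem_limit * (1.0 - eps) // mem_slots
    print(int(mem_unit))                 # bytes represented by one DP memory slot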
@@ -5,24 +5,24 @@ class Chain:
         self.bweight = bw
         self.cweight = cw
         self.cbweight = cbw
-        self.fwd_tmp = ftmp
-        self.bwd_tmp = btmp
+        self.fwd_mem_tmp = ftmp
+        self.bwd_mem_tmp = btmp
         self.length = len(fw)
         if check and not self.check_lengths():
             raise AttributeError("In Chain, input lists do not have consistent lengths")
 
     def check_lengths(self):
         return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
-                and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
-                and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
+                and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
+                and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
 
     def __repr__(self):
         chain_list = []
         for i in range(self.length):
-            chain_list.append(
-                (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
+            chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
+                               self.bwd_mem_tmp[i]))
         i = self.length
-        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
+        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
         return chain_list.__repr__()
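Note: `check_lengths` encodes the shape convention of the chain: forward-side lists have one entry per stage, while backward-side lists carry one extra slot for the loss. A tiny hypothetical construction (argument order assumed from the assignments shown above; requires the `Chain` class from this file):

    fw = [1, 2, 3]            # forward time per stage (length == 3)
    bw = [2, 3, 4, 1]         # backward time (length + 1, extra slot for loss)
    cw = [4, 4, 4, 4]         # size of x at each boundary (length + 1)
    cbw = [8, 8, 8, 8]        # size of xbar (length + 1)
    ftmp = [0, 0, 0]          # fwd_mem_tmp (length)
    btmp = [0, 0, 0, 0]       # bwd_mem_tmp (length + 1)
    chain = Chain(fw, bw, cw, cbw, ftmp, btmp)    # passes check_lengths()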
@@ -94,12 +94,11 @@ class MetaInfoProp(torch.fx.Interpreter):
 
         tensor_meta = tree_map(extract_tensor_meta, result)
         n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
-
+        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
         for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
+            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
         n.meta['type'] = type(result)
 
         # retain the autograd graph
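Note: the substantive fix in this hunk is `max` in place of `+=` when a consumer pushes its `fwd_mem_in` back to a producer: a tensor saved once serves every consumer, so accumulating would double-count producers with fan-out greater than one. Minimal illustration:

    par_meta = {'fwd_mem_out': 0}
    consumers_fwd_mem_in = [512, 512]    # two consumers saving the same tensor

    for need in consumers_fwd_mem_in:
        par_meta['fwd_mem_out'] = max(par_meta['fwd_mem_out'], need)
    print(par_meta['fwd_mem_out'])       # 512 with max; += would report 1024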
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size
 
 
-class Stage(Enum):
+class Phase(Enum):
     FORWARD = 0
     LOSS = 1
     BACKWARD = 2
@@ -48,24 +49,9 @@ class GraphInfo:
     bwd_mem_out: int = 0
 
 
-def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
-
-
-def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
-
-
-def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
-
-
-def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+def is_phase(n: Node, phase: Phase) -> bool:
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == phase
 
 
 def is_saved(n: Node):
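Note: the four per-stage predicates collapse into a single parameterized `is_phase`, which composes with `functools.partial` (hence the new import above) wherever a one-argument predicate is needed. A self-contained sketch; the `PLACEHOLDER` member is assumed, since this hunk cuts off after `BACKWARD = 2` but `Phase.PLACEHOLDER` is referenced later:

    from enum import Enum
    from functools import partial

    class Phase(Enum):
        FORWARD = 0
        LOSS = 1
        BACKWARD = 2
        PLACEHOLDER = 3      # assumed member, used as Phase.PLACEHOLDER below

    class FakeNode:          # stand-in for a torch.fx.Node with a meta dict
        def __init__(self, phase):
            self.meta = {'phase': phase}

    def is_phase(n, phase):
        assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
        return n.meta['phase'] == phase

    is_loss = partial(is_phase, phase=Phase.LOSS)    # drop-in for the old is_loss
    print(any(map(is_loss, [FakeNode(Phase.FORWARD), FakeNode(Phase.LOSS)])))    # True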
@@ -74,7 +60,7 @@ def is_saved(n: Node):
 
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
-    Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+    Basically the input graph should have all nodes marked for keyword `phase`.
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ----> p o <---- We need to keep track of grad out
@@ -91,18 +77,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                       l
     =============================================================================
     Args:
-        graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+        graph (Graph): The autograd graph with nodes marked for keyword `phase`.
 
     Returns:
         graph_info (GraphInfo): Meta information for the dataflow.
     """
 
     def _peak_memory(deps: Dict[Node, int]):
-        bwd_tmp = 0
+        peak_mem = 0
         for k, v in deps.items():
             if v > 0:
-                bwd_tmp += activation_size(k.meta['out'])
-        return bwd_tmp
+                peak_mem += activation_size(k.meta['out'])
+        return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
     deps = {}
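Note: besides the rename to `peak_mem`, `_peak_memory` (like `_get_deps_size` in the rotor solver above) only counts entries whose dependency count is still positive: fully consumed gradients are dead and must not inflate the peak. Minimal illustration:

    deps = {'a': 2, 'b': 0, 'c': 1}              # node -> outstanding uses
    out_bytes = {'a': 100, 'b': 50, 'c': 25}     # hypothetical activation sizes

    live = sum(out_bytes[k] for k, v in deps.items() if v > 0)
    print(live)                                  # 125: 'b' is dead and excluded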
@@ -110,19 +96,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(is_loss, n.users)):
+        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_in`.
-            # Any `fwd_in` should be kept in memory even this function
+            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+            # Any `fwd_mem_in` should be kept in memory even this function
             # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
-            # the node, `fwd_tmp` can be freed.
-            if is_placeholder(n):
+            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+            # the node, `fwd_mem_tmp` can be freed.
+            if is_phase(n, Phase.PLACEHOLDER):
                 graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_forward(n):
+            if is_phase(n, Phase.FORWARD):
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
-        elif is_backward(n):
+        elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)
@@ -5,8 +5,8 @@ import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import autograd_graph_analysis, Stage
-from .memory import WEIRD_OPS
+from .dataflow import GraphInfo, autograd_graph_analysis, Phase
+from .memory import WEIRD_OPS, activation_size
 from .tensor import MetaTensor
 from .opcount import flop_mapping
 
 
@@ -41,14 +41,11 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
 
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
-        Stage.FORWARD: 0,
-        Stage.LOSS: 0,
-        Stage.BACKWARD: 0,
+        Phase.FORWARD: 0,
+        Phase.LOSS: 0,
+        Phase.BACKWARD: 0,
     }
 
-    # `stage` will mark the stage of autograd from outside scope.
-    stage = Stage.FORWARD
-
     # FlopTensor not only get the flop statistics of a single node,
     # it also build a full autograd graph for this node.
     # This makes sure we can analyze the dependencies of memory, and
@@ -85,9 +82,9 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
 
         # run aten for backend=CPU but actually on backend=Meta
         out = func(*args, **kwargs)
-        flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
+        flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
         node.meta['out'] = normalize_tuple(out)
-        node.meta['stage'] = stage
+        node.meta['phase'] = phase
 
     def wrap(x):
         return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
@@ -121,7 +118,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
         x._node = subgraph.create_node('placeholder',
                                        'placeholder', (subgraph._root,),
                                        name=subgraph._graph_namespace.create_name('input', x._tensor))
-        x._node.meta['stage'] = Stage.PLACEHOLDER
+        x._node.meta['phase'] = Phase.PLACEHOLDER
         x._node.meta['out'] = (x._tensor,)
 
     tree_map(set_placeholder, args)
@@ -135,6 +132,8 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     def unpack(x):
         return x
 
+    # `phase` will mark the phase of autograd from outside scope.
+    phase = Phase.FORWARD
     # mark saved tensors with saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
         if isinstance(target, str):
@@ -147,13 +146,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
     if is_autogradable(out) and out.requires_grad:
-        stage = Stage.LOSS
+        phase = Phase.LOSS
         loss = out.sum()
-        stage = Stage.BACKWARD
+        phase = Phase.BACKWARD
         loss.backward()
 
     graph_info = autograd_graph_analysis(subgraph)
-    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
+    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
 
     def unwrap(x):
         return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
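Note: `_profile` flips a single `phase` variable (FORWARD, then LOSS around `out.sum()`, then BACKWARD around `loss.backward()`), so the dispatch hook can tag every autograd node with the phase it ran in, while `saved_tensors_hooks` marks what autograd keeps alive. A minimal standalone demo of that PyTorch hook, independent of the profiler here:

    import torch

    saved_shapes = []

    def pack(x):
        saved_shapes.append(tuple(x.shape))    # record what autograd saves
        return x

    def unpack(x):
        return x

    a = torch.randn(4, 4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        loss = (a * a).sum()
    loss.backward()
    print(saved_shapes)    # shapes of tensors saved for the backward pass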
@@ -180,6 +179,11 @@ def profile_function(target: 'Target') -> Callable:
 
         # If there is an argument that this `call_function` is inplace, we should
         # skip the autograd profiling.
+        if kwargs.get('inplace', False):
+            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
+            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta
 
@@ -222,6 +226,11 @@ def profile_module(module: torch.nn.Module) -> Callable:
 
         # If there is an argument that this `call_module` is inplace, we should
         # skip the autograd profiling.
+        if getattr(module, 'inplace', False):
+            args = tree_map(lambda x: x.to('meta'), args)
+            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
         return out, meta
 
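Note: both fast paths use `tree_map` (already imported in this file) to move every tensor inside an arbitrarily nested `args`/`kwargs` structure onto the meta device before running the op once, so in-place calls are measured without real allocation or autograd profiling. Standalone illustration of that move:

    import torch
    from torch.utils._pytree import tree_map

    args = (torch.randn(2, 2), [torch.randn(3)], 'not a tensor')
    meta_args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
    print(meta_args[0].device, meta_args[1][0].device)    # meta meta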
@@ -38,7 +38,8 @@ def test_linearize():
             if isinstance(op, ForwardNograd):
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                    assert n.activation_checkpoint[
+                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
                 continue
@@ -54,7 +55,8 @@ def test_linearize():
                 ckpt_idx += 1
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                    assert n.activation_checkpoint[
+                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
                 continue
@@ -63,7 +65,8 @@ def test_linearize():
                 in_ckpt = True
                 for n in node_list[idx]:
                     assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                    assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                    assert n.activation_checkpoint[
+                        0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
     del model
     del gm
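Note: the test updates reflect that `activation_checkpoint` is now a sequence rather than a bare index (consistent with the `_find_nested_ckpt_regions` import in the solver: nested regions need one id per nesting level), so the assertions compare element `[0]`, the outermost region. Shape of the change, with hypothetical values:

    class FakeNode:                       # stand-in for an annotated torch.fx Node
        activation_checkpoint = [0]       # was: activation_checkpoint = 0

    n, ckpt_idx = FakeNode(), 0
    assert n.activation_checkpoint[0] == ckpt_idx    # new form of the assertion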