mirror of https://github.com/hpcaitech/ColossalAI
[fx] Add activation checkpoint solver rotor (#1496)
* [fx] fix defining ckpt functions inside forward
* [fx] Modify activation checkpoint codegen and add ColoGraphModule
* [fx] some modification
* some modifications
* some modifications
* some modifications
* some modifications
* some code modifications
* [automatic_parallel] ckpt solver rotor
* [fx] add ckpt_solver_rotor
* [fx] modification
* code refactor
* code refactor

pull/1500/head
parent 09c023bee2
commit de1e716dc4
@@ -1 +1,2 @@
from .tracer import ColoTracer
from .graph_module import ColoGraphModule
@@ -1 +1,3 @@
from .ckpt_solver_chen import chen_greedy
from .linearize import linearize
from .ckpt_solver_rotor import solver_rotor
@@ -0,0 +1,198 @@
from typing import List, Set, Tuple, Dict
import torch
from torch.fx import GraphModule, Node
import math
from .linearize import linearize
from .utils import *
from colossalai.fx.profiler import profile_function, profile_module
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
    """Returns the optimal table: a tuple containing:
    Opt[m][lmin][lmax] with lmin = 0...chain.length
    and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
    what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
    (False, j) if the optimal choice is a leaf checkpoint of length j
    The computation uses dynamic programming"""

    fw = chain.fweight + [0]    ## forward time
    bw = chain.bweight    ## backward time, not used
    cw = chain.cweight + [0]    ## size of x (and of y)
    cbw = chain.cbweight + [0]    ## size of xbar
    fwd_tmp = chain.fwd_tmp + [0]
    bwd_tmp = chain.bwd_tmp + [0]

    # Build table
    opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

    # Initialize borders of the tables for lmax-lmin = 0
    for m in range(mmax + 1):
        for i in range(chain.length + 1):
            # lmax-lmin = 0
            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
            if m >= limit:    ## Equation (1)
                opt[m][i][i] = fw[i] + bw[i]
            else:
                opt[m][i][i] = float("inf")

    # Compute everything
    for m in range(mmax + 1):
        for d in range(1, chain.length + 1):
            for i in range(chain.length + 1 - d):
                # for idx in range(i+1, chain.length + 1):
                idx = i + d
                mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
                if idx > i + 1:
                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
                if m < mmin:
                    opt[m][i][idx] = float("inf")
                else:
                    leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
                                        for j in range(i + 1, idx + 1)
                                        if m >= cw[j]]
                    if leaf_checkpoints:
                        best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                    else:
                        best_leaf = None
                    if m >= cbw[i + 1]:
                        chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
                    else:
                        chain_checkpoint = float("inf")
                    if best_leaf and best_leaf[1] <= chain_checkpoint:
                        opt[m][i][idx] = best_leaf[1]
                        what[m][i][idx] = (False, best_leaf[0])
                    else:
                        opt[m][i][idx] = chain_checkpoint
                        what[m][i][idx] = (True,)
    return (opt, what)


def _rec(chain, lmin, lmax, cmem, opt_table):
    """ chain : the class describing the AC graph
        lmin : index of the first forward to execute
        lmax : upper bound index of the last forward to execute (not included)
        cmem : number of available memory slots
        Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
    if cmem <= 0:
        raise ValueError("Cannot process a chain with negative memory {cmem}".format(cmem=cmem))
    opt, what = opt_table
    sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
    if opt[cmem][lmin][lmax] == float("inf"):
        raise ValueError("Cannot process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
                                                                                                            lmax=lmax,
                                                                                                            cmem=cmem))
    if lmin == lmax:
        if lmin == chain.length:
            sequence.insert(Loss())
        else:
            sequence.insert(ForwardEnable(lmin))
            sequence.insert(Backward(lmin))
        return sequence

    if what[cmem][lmin][lmax][0]:
        sequence.insert(ForwardEnable(lmin))
        sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
        sequence.insert(Backward(lmin))
    else:
        j = what[cmem][lmin][lmax][1]
        sequence.insert(ForwardCheck(lmin))
        for k in range(lmin + 1, j):
            sequence.insert(ForwardNograd(k))
        sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
        sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
    return sequence


def _discretize(mem_unit, values):
    return [math.ceil(value / mem_unit) for value in values]


def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:

    fwd_time = []
    bwd_time = []
    xbar_sizes = [data.numel() * data.element_size()]
    x_sizes = [data.numel() * data.element_size()]

    # currently we can't get the temp memory needed in fwd and bwd
    tmp_fwd = [0] * len(node_dict)
    tmp_bwd = [0] * (len(node_dict) + 1)

    for key in node_dict.keys():
        fwd_time.append(0)
        bwd_time.append(0)
        xbar_sizes.append(0)
        x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
                       torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
        for node in node_dict[key]:
            fwd_time[-1] += node.__flops__

            # currently we haven't patched the backward flops count
            bwd_time[-1] += node.__flops__ * 2

            xbar_sizes[-1] += node.__activation__

        xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])

    bwd_time.append(0)

    fwd_time = _discretize(mem_unit, fwd_time)
    bwd_time = _discretize(mem_unit, bwd_time)
    xbar_sizes = _discretize(mem_unit, xbar_sizes)
    x_sizes = _discretize(mem_unit, x_sizes)
    tmp_fwd = _discretize(mem_unit, tmp_fwd)
    tmp_bwd = _discretize(mem_unit, tmp_bwd)

    return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)


def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> None:
    op_list = sequence.list_operations()
    loss_op = [op for op in op_list if isinstance(op, Loss)][0]
    op_list = op_list[:op_list.index(loss_op)]
    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []
    for idx, op in enumerate(op_list, 1):
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(idx)

            elif isinstance(op, ForwardEnable):
                in_ckpt = False
                for region_idx in ckpt_region:
                    for node in node_dict[region_idx]:
                        setattr(node, "activation_checkpoint", ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for region_idx in ckpt_region:
                    for node in node_dict[region_idx]:
                        setattr(node, "activation_checkpoint", ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [idx]

        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(idx)


def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
    node_dict = linearize(gm)
    mem_unit = mem_limit // mem_slots
    MetaInfoProp(gm).run(data)
    chain: Chain = _construct_chain(node_dict, data, mem_unit)
    opt_table = _compute_table(chain, mem_slots)
    sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
    _annotate_from_sequence(sequence, node_dict)
    return gm
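
For context, a minimal usage sketch of the new solver (not part of this commit): the torchvision model, input shape, and 500 MB budget below are illustrative assumptions only, and the codegen setup follows the updated test further down.

import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor

# illustrative model and budget, not taken from this commit
model = tm.resnet18()
data = torch.rand(4, 3, 224, 224)

# trace the model and wrap it in a ColoGraphModule, as the updated tests do
graph = ColoTracer().trace(model)
gm = ColoGraphModule(model, graph, model.__class__.__name__)

# annotate nodes with `activation_checkpoint` indices under a 500 MB budget
gm = solver_rotor(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)

As in `_run_ckpt_solver` below, the graph's codegen can be switched to `ActivationCheckpointCodeGen` beforehand so the annotations are picked up when code is regenerated.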
@@ -0,0 +1,89 @@
from collections import OrderedDict

from torch.fx import GraphModule


def linearize(gm: GraphModule) -> dict:
    """Linearize the graph of a GraphModule into an OrderedDict that maps a
    region index to the list of Nodes executed in that region."""
    status_dict = {}
    node_dict = OrderedDict()
    node_idx = 0
    for node in gm.graph.nodes:
        last_dict_len = len(status_dict)
        # remove node from users lists in status_dict
        for item in status_dict.values():
            if node in item:
                item.remove(node)

        # pop node from status_dict if it is fully used
        for key in list(status_dict):
            if len(status_dict[key]) == 0:
                status_dict.pop(key)

        # first node in the graph, it should be of n0-n1 type,
        # where n0 contains only the input op, i.e. placeholder
        if last_dict_len == 0:
            node_dict[node_idx] = [node]
            status_dict[node.name] = list(node.users)
            node_idx += 1
            node_dict[node_idx] = []

            continue

        # boundary case
        if len(status_dict) == 0:
            # current region end point = next region start point,
            # i.e. n1-n2-n3-... type nodes, each region contains only one op
            if last_dict_len == 1:
                if len(node_dict[node_idx]) > 0:
                    node_idx += 1
                    node_dict[node_idx] = []
                node_dict[node_idx].append(node)
                status_dict[node.name] = list(node.users)

                continue

            # n1-n2-n3, if n1 has multiple ops, the last op in n1 is the one
            # able to clean all others in status_dict; since last_dict_len > 1,
            # multiple ops are used by this node, so we view it as the end of
            # one region and the start of a new one
            else:
                node_dict[node_idx].append(node)
                status_dict[node.name] = list(node.users)
                node_idx += 1
                node_dict[node_idx] = []

                continue

        else:
            # currently I will use the bigger node structure
            # if the following region is activated, the node will be smaller
            #################################################
            # if last_dict_len == 1:
            #     if len(node_dict[node_idx]) > 0:
            #         node_idx += 1
            #     node_dict[node_idx] = [node]
            #     status_dict[node.name] = list(node.users)
            #
            #     continue
            #################################################

            # in-node case: the current node cannot clean status_dict, so we
            # view it as in-node status and append it to the current node_idx
            node_dict[node_idx].append(node)
            status_dict[node.name] = list(node.users)

            continue

    # if the output node uses multiple nodes, there might be an
    # empty region after the output node
    if len(node_dict[node_idx]) == 0:
        node_dict.pop(node_idx)
        node_idx -= 1

    # pop the input (placeholder) region and the output region
    node_dict.pop(0)
    node_dict.pop(node_idx)
    return node_dict
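
To make the return structure concrete, a small sketch; the toy module and the exact grouping are illustrative, since the real regions depend on the traced graph.

import torch
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import linearize

# illustrative toy network, not part of this commit
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyNet()
gm = ColoGraphModule(model, ColoTracer().trace(model), model.__class__.__name__)

# node_dict maps region index -> list of fx Nodes in that region,
# with the input (placeholder) and output regions already popped
node_dict = linearize(gm)
for region_idx, nodes in node_dict.items():
    print(region_idx, [n.name for n in nodes])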
@@ -0,0 +1,229 @@
class Chain:

    def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
        self.fweight = fw
        self.bweight = bw
        self.cweight = cw
        self.cbweight = cbw
        self.fwd_tmp = ftmp
        self.bwd_tmp = btmp
        self.length = len(fw)
        if check and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
                and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
                and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))

    def __repr__(self):
        chain_list = []
        for i in range(self.length):
            chain_list.append(
                (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
        i = self.length
        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
        return chain_list.__repr__()


class Operation:

    def shift(self, value):
        if type(self.index) is tuple:
            self.index = tuple(x + value for x in self.index)
        else:
            self.index += value


class Forward(Operation):

    def __init__(self, index):
        self.index = index
        self.name = "F"

    def __repr__(self):
        return "{n}_{i}".format(n=self.name, i=self.index)

    def cost(self, chain):
        if chain is not None:
            return chain.fweight[self.index]
        else:
            return 1


class ForwardEnable(Forward):

    def __init__(self, index):
        super().__init__(index)
        self.name = "Fe"


class ForwardNograd(Forward):

    def __init__(self, index):
        super().__init__(index)
        self.name = "Fn"


class ForwardCheck(Forward):

    def __init__(self, index):
        super().__init__(index)
        self.name = "CF"


class Forwards(Operation):

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])

    def cost(self, chain):
        if chain is not None:
            return sum(chain.fweight[self.index[0]:self.index[1] + 1])
        else:
            return (self.index[1] - self.index[0] + 1)


def isForward(op):
    return type(op) is Forward or type(op) is Forwards


class Backward(Operation):

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return "B_{i}".format(i=self.index)

    def cost(self, chain):
        if chain is not None:
            return chain.bweight[self.index]
        else:
            return 1


class Loss(Operation):

    def __init__(self):
        pass

    def __repr__(self):
        return "L"

    def cost(self, chain):
        return 0


class MemoryAccess(Operation):

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return "{n}_{i}".format(n=self.name, i=self.index)

    def cost(self, chain):
        return 0


class WriteMemory(MemoryAccess):

    def __init__(self, index):
        super().__init__(index)
        self.name = "WM"


class ReadMemory(MemoryAccess):

    def __init__(self, index):
        super().__init__(index)
        self.name = "RM"


class DiscardMemory(MemoryAccess):

    def __init__(self, index):
        super().__init__(index)
        self.name = "DM"


class Function:

    def __init__(self, name, *args):
        self.name = name
        self.args = args
        self.str_args = ','.join(str(v) for v in self.args)

    def __repr__(self):
        return "{n}({args})".format(n=self.name, args=self.str_args)


class Sequence:

    def __init__(self, function):
        self.sequence = []    # list of Operation and Sequence
        self.function = function    # describes the function (name and parameters)

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        op_list = []
        for x in self.sequence:
            if isinstance(x, Operation):
                op_list.append(x)
            else:
                assert isinstance(x, Sequence)
                op_list += x.list_operations()
        return op_list

    def insert(self, operation):
        self.sequence.append(operation)

    def remove(self, operation_index):
        del self.sequence[operation_index]

    def insert_sequence(self, sequence):
        self.sequence.append(sequence)

    def shift(self, value):
        for x in self.sequence:
            x.shift(value)
        return self

    def remove_useless_write(self):
        if self.sequence:
            if isinstance(self.sequence[0], WriteMemory):
                self.remove(0)
        return self

    def get_makespan(self, chain):
        return sum(op.cost(chain) for op in self.list_operations())

    def without_suffix(self):
        ops = self.list_operations()
        end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
        try:
            last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
        except ValueError:
            last_idx = -1
        if last_idx == end_of_first_phase - 1:
            return (self, None)
        chain_length = ops[end_of_first_phase -
                           1].index    ## Some assumption here about the sequence (finishes with Forward_L
        start_of_fwd_enable_chain = ops[last_idx + 1].index    ## And starts with B_L), but should be fine in practice
        result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
        for i in range(last_idx + 1):
            result.insert(ops[i])
        result.insert(Loss())
        for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
            position = end_of_first_phase + 1 + (chain_length - i)
            assert type(ops[position]) is Backward
            assert ops[position].index == i
        for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
            result.insert(ops[i])
        return (result, start_of_fwd_enable_chain)
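
A hedged, self-contained sketch of how Chain, the compute table, and Sequence fit together, bypassing tracing entirely; the costs are arbitrary, already-discretized numbers, and importing the private helpers `_compute_table` / `_rec` directly is done here only for illustration.

from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _compute_table, _rec
from colossalai.fx.passes.algorithms.utils import Chain

# a 3-stage chain with made-up discretized costs (not measurements)
fw = [1, 1, 1]          # forward time per stage (length entries)
bw = [1, 1, 1, 1]       # backward time (length + 1 entries)
cw = [1, 1, 1, 1]       # size of each x (length + 1 entries)
cbw = [2, 2, 2, 2]      # size of each xbar (length + 1 entries)
ftmp = [0, 0, 0]
btmp = [0, 0, 0, 0]

chain = Chain(fw, bw, cw, cbw, ftmp, btmp)
mem_slots = 10
opt_table = _compute_table(chain, mem_slots)
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
print(sequence)         # flattened list of Fe / Fn / CF / L / B operations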

@@ -9,7 +9,7 @@ import colossalai
 from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy
+from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest

@@ -22,7 +22,7 @@ except:
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
     with_codegen = False

-SOLVERS = [chen_greedy]
+SOLVERS = [chen_greedy, solver_rotor]


 def _is_activation_checkpoint_available(gm: GraphModule):

@@ -77,7 +77,10 @@ def _run_ckpt_solver(rank):
         MetaInfoProp(gm).run(data)
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
-        gm = solver(gm)
+        if solver == solver_rotor:
+            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+        else:
+            gm = solver(gm)
         assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
         assert _is_activation_checkpoint_available(
             gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"

@@ -106,7 +109,10 @@ def _run_ckpt_solver_torch11(rank):
         gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
-        gm = solver(gm)
+        if solver == solver_rotor:
+            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+        else:
+            gm = solver(gm)
         assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
         assert _is_activation_checkpoint_available(
             gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"