mirror of https://github.com/hpcaitech/ColossalAI
[fx] Add pofo solver (#1608)
* [fx] add pofo algorithm
* [fx] Add pofo solver
* [fx] code refactor
* [fx] fix test_linearize import

pull/1611/head
parent
d32cf84c46
commit
933b6c6367

@@ -1,3 +1,4 @@
from .ckpt_solver_chen import chen_greedy
from .linearize import linearize
from .ckpt_solver_rotor import solver_rotor
+from .ckpt_solver_pofo import solver_pofo
@@ -0,0 +1,404 @@
from typing import List, Tuple
import copy
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec

INF = float("inf")


def _normalize_flops(chain: Chain, flops) -> Chain:
    """
    Normalize the forward and backward weights of the chain by the device FLOPS,
    so that they represent execution time.
    """
    for i in range(chain.length):
        chain.fweight[i] /= flops
        chain.bweight[i] /= flops

    return chain


class PofoTable:
    """PofoTable
    The PofoTable contains the necessary components to store the intermediate results
    of the dynamic programming and the operations along the way.
    """

    def __init__(self, chain_length: int, mem_slots: int):
        """Init pofo table
        The pofo table contains two tables, opt and what, holding values and
        decisions respectively.

        Args:
            chain_length (int): chain length
            mem_slots (int): number of memory slots
        """

        self.length = chain_length
        self.mem_slots = mem_slots

        # initializing tables
        # the first bool indicates whether the input has bar
        # the opt table stores values: opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
        # the what table stores decisions: what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
        # where is_enable indicates whether we enable the gradient, is_offload indicates whether we
        # offload the input, and index indicates the end of the F_\empty sequence when is_enable = False
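        # for illustration: what[False][3][A][(df, db)] == (False, True, 5) would mean that at stage 3,
        # with a plain input x_3, we offload the input, run F_ck at stage 3 and F_\empty at stages 4 and 5,
        # then recurse from stage 6; an entry (True, do_offload) means we start with F_all at the current stage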
        self.opt = {
            False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
            True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
        }
        self.what = {
            False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
            True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
        }

    def _get_value(self, state, table, default):
        i, act_size, df, db, input_has_bar = state
        if act_size + df > self.mem_slots or act_size + db > self.mem_slots:
            return default

        try:
            return table[input_has_bar][i][act_size][(df, db)]
        except KeyError:
            print(f"state not found {state}")

    def get_opt(self, state):
        return self._get_value(state, self.opt, INF)

    def get_what(self, state):
        return self._get_value(state, self.what, INF)

    def set_value(self, state, opt, what):
        i, act_size, df, db, input_has_bar = state
        self.opt[input_has_bar][i][act_size][(df, db)] = opt
        self.what[input_has_bar][i][act_size][(df, db)] = what


class PofoSolver:
    """PofoSolver that executes the algorithm introduced in
    https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html

    The pofo solver is based on the paper "Efficient Combination of Rematerialization and Offloading for Training DNNs"
    and the code given in its supplemental material. Currently we do not use the whole setup of the original paper and
    reuse the rotor solver for the backward sequence, as suggested in the supplemental material. The solver is now able
    to find a strategy that combines activation checkpointing with offloading.
    """

    def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None:
        self.chain = chain
        self.length = chain.length
        self.max_memory = max_memory
        self.mem_slots = mem_slots
        self.mem_unit = max_memory / mem_slots
        self.bandwidth = bandwidth

        self.disc_chain = copy.deepcopy(self.chain)

        self.rotor_table = _compute_table(self.disc_chain, mem_slots)
        self._compute_pofo_table()

    def _discretize(self, *values) -> Tuple:
        return tuple(math.ceil(value / self.mem_unit) for value in values)

    def _undiscretize(self, *discrete_values) -> Tuple:
        if len(discrete_values) == 1:
            return discrete_values[0] * self.mem_unit
        else:
            return tuple(d * self.mem_unit for d in discrete_values)

    def _mmax_all(self, idx: int):
        """
        Calculate the maximum memory usage of Fi_all
        """

        return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx]

    def _mmax_b(self, idx: int):
        """
        Calculate the maximum memory usage of Bi
        """

        return (self.chain.cbweight[idx + 1] + self.chain.cweight[idx + 1] + self.chain.cweight[idx] +
                self.chain.bwd_mem_tmp[idx])

    def _mmax_ng(self, i: int, j: int):
        """
        Calculate the maximum memory usage of CF_i, F_{i+1}^\empty, ..., F_j^\empty
        """

        res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j]
        if j > i:
            res += self.chain.cweight[j]
        return res

    def _rotor_estimated_bwd(self, i, j, m, delta):
        # approximate the backward time as the average of the fully overlapped and the
        # non-overlapped cost of recomputation and prefetching delta bytes
        compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j]
        comm = delta / self.bandwidth
        return (max(compute, comm) + compute + comm) / 2

    def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
        return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)

    def _common_values_enable(self, state: Tuple):

        idx, act_size, df, db, input_has_bar = state
        input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
        mf = act_size + df + input_size
        mb = act_size + db + input_size
        mem_avail = self.max_memory - act_size - input_size
        f_usage = self._mmax_all(idx)
        b_usage = self._mmax_b(idx)

        # infeasible
        if f_usage > mem_avail or b_usage > mem_avail:
            return None

        # calculate idle time
        eps_f_beta = max(0, f_usage - self.max_memory + mf)
        eps_b_beta = max(0, b_usage - self.max_memory + mb)
        idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth

        # calculate offload and prefetch data
        offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta
        prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta

        # total time
        total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time

        return (offload_data, prefetch_data, total_time, idle_time)

    def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False):

        i, act_size, df, db, input_has_bar = state

        # compute new epsilon_tmp and sum_fwds
        if iterative:
            self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds)
            self.sum_fwds += self.chain.fweight[j]
        else:
            self.epsilon_tmp = max(
                self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1))
            self.sum_fwds = sum(self.chain.fweight[i:j + 1])

        input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]
        mf = act_size + df + input_size
        mem_avail = self.max_memory - act_size - input_size

        # if infeasible
        if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail:
            return None

        eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf)
        offload_data = self.sum_fwds * self.bandwidth + eps_f_beta

        # TODO: Implement the precise backward recompute sequence mentioned in the paper
        # currently we will use an approximate way to get the backward time
        time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db)

        prefetch_data = time_backward * self.bandwidth
        idle_time = eps_f_beta / self.bandwidth
        total_time = self.sum_fwds + idle_time + time_backward

        return (offload_data, prefetch_data, total_time, idle_time)

    def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple:
        """Generate new values for the next state

        Args:
            state (Tuple): undiscretized state
            do_offload (bool): whether we need to offload the input
            common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)

        Returns:
            Tuple: (new_act_size, new_df, new_db)
        """
        idx, act_size, df, db, input_has_bar = state
        offload_data, prefetch_data, *_ = common_values
        input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
        if do_offload:
            new_act_size = act_size
            new_df = max(0, df + input_size - offload_data)
            new_db = max(0, db - prefetch_data) + input_size
        else:
            new_act_size = act_size + input_size
            new_df = max(0, df - offload_data)
            new_db = max(0, db - prefetch_data)

        return (new_act_size, new_df, new_db)

    def _compute_pofo_table(self):
        self.table = PofoTable(self.length, self.mem_slots)

        # initializing the loss
        for act_size in range(self.mem_slots + 1):
            for df in range(self.mem_slots - act_size + 1):
                for db in range(self.mem_slots - act_size + 1):
                    # undiscretize for idle time calculation
                    origin_values = self._undiscretize(act_size, df, db)

                    for input_has_bar in (False, True):
                        disc_state = (self.length, act_size, df, db, input_has_bar)
                        state = (self.length, *origin_values, input_has_bar)
                        common_values = self._common_values_enable(state)

                        # if no feasible choice
                        if common_values is None:
                            self.table.set_value(disc_state, INF, None)
                            continue

                        # if there is a feasible choice
                        new_act_size, new_df, new_db = self._new_values(state, False, common_values)
                        eps_g = (new_df + new_db) / self.bandwidth
                        total_time = common_values[2] + eps_g
                        self.table.set_value(disc_state, total_time, (True, False))

        # main loop
        for i in reversed(range(self.length)):
            for act_size in range(self.mem_slots + 1):
                for df in range(self.mem_slots - act_size + 1):
                    for db in range(self.mem_slots - act_size + 1):
                        # undiscretize for idle time calculation
                        origin_values = self._undiscretize(act_size, df, db)

                        for input_has_bar in (False, True):
                            best_result = INF
                            best_choice = None
                            disc_state = (i, act_size, df, db, input_has_bar)
                            state = (i, *origin_values, input_has_bar)

                            # case 1: start with F_all
                            vals_enable = self._common_values_enable(state)
                            if vals_enable is not None:
                                for do_offload in (True, False):
                                    new_state = self._new_values(state, do_offload, vals_enable)
                                    new_state = (i + 1, *self._discretize(*new_state), True)
                                    total_time = vals_enable[2]
                                    results_all = self.table.get_opt(new_state) + total_time
                                    if results_all < best_result:
                                        best_result = results_all
                                        best_choice = (True, do_offload)

                            # case 2: start with F_ck
                            self.sum_fwds = 0
                            self.epsilon_tmp = 0
                            for j in range(i, self.length):
                                vals_nograd = self._common_values_nograd(state, j, True)

                                # if infeasible
                                if vals_nograd is None:
                                    continue

                                for do_offload in (True, False):
                                    new_state = self._new_values(state, do_offload, vals_nograd)
                                    new_state = (j + 1, *self._discretize(*new_state), False)
                                    total_time = vals_nograd[2]
                                    result_nograd = total_time + self.table.get_opt(new_state)
                                    if result_nograd < best_result:
                                        best_result = result_nograd
                                        best_choice = (False, do_offload, j)

                            self.table.set_value(disc_state, best_result, best_choice)

    def pofo_rec(self, disc_state):
        i, act_size, df, db, input_has_bar = disc_state
        result = Sequence(Function("pofo", *disc_state))
        what = self.table.get_what(disc_state)
        state = self._undiscretize(act_size, df, db)
        state = (i, *state, input_has_bar)
        i, act_size, df, db, input_has_bar = state

        if what is None:
            return None

        # if loss
        if i == self.length:
            result.insert(Loss())
            return result

        if what[0]:
            do_offload = what[1]
            values = self._common_values_enable(state)
            new_state = self._discretize(*self._new_values(state, do_offload, values))
            new_state = (i + 1, *new_state, True)
            if do_offload:
                result.insert(Offload(i, input_has_bar))
            result.insert(ForwardEnable(i))
            result.insert_sequence(self.pofo_rec(new_state))
            if do_offload:
                result.insert(Prefetch(i, input_has_bar))
            result.insert(Backward(i))

        else:
            _, do_offload, j = what
            values = self._common_values_nograd(state, j)
            new_state = self._discretize(*self._new_values(state, do_offload, values))
            new_state = (j + 1, *new_state, False)
            if do_offload:
                result.insert(Offload(i, input_has_bar))
            result.insert(ForwardCheck(i))
            for k in range(i + 1, j + 1):
                result.insert(ForwardNograd(k))
            result.insert_sequence(self.pofo_rec(new_state))
            if do_offload:
                result.insert(Prefetch(i, input_has_bar))
            m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i])

            # TODO: Implement the precise backward recompute sequence mentioned in the paper
            result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db))

        return result


def solver_pofo(gm: ColoGraphModule,
                data,
                bandwidth,
                flops,
                mem_limit: int,
                mem_slots: int = 50,
                cnode: List[str] = None,
                eps: float = 0.0) -> ColoGraphModule:
    """Solver that combines offloading with activation checkpointing
    Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html

    Args:
        gm (ColoGraphModule): ColoGraphModule derived from the tracer
        data: input of the model
        bandwidth: offload bandwidth, unit Byte/s
        flops: compute throughput of the device, unit FLOPs/s
        mem_limit (int): memory limit, unit Byte
        mem_slots (int, optional): number of memory slots. Defaults to 50.
        cnode (List[str], optional): common node for linearize. Defaults to None.
        eps (float, optional): epsilon for memory decay. Defaults to 0.0.

    Returns:
        ColoGraphModule: annotated graph module
    """

    node_list = linearize(gm, cnode)
    mem_limit -= parameter_size(gm)

    # prepare data
    MetaInfoProp(gm).run(data)
    chain: Chain = _construct_chain(node_list, data)
    chain = _normalize_flops(chain, flops)
    # currently we view the loss as an op with no cost
    chain.cbweight.append(0)
    chain.cweight.append(0)
    chain.fwd_mem_tmp.append(0)
    chain.bwd_mem_tmp.append(0)
    chain.fweight.append(0)
    chain.bweight.append(0)

    solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots)
    first_state = (0, 0, 0, 0, False)
    sequence = solver.pofo_rec(first_state)
    if sequence is None:
        print(f"Cannot solve a strategy with {mem_limit / 1024**2} MB memory!")

    setattr(gm, "__sequence__", sequence)
    return gm
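Usage sketch (illustrative, not part of this commit): the snippet below shows roughly how the new solver could be invoked, following the pattern of the repository's solver_rotor tests; the ResNet model, the tracer call with meta_args, and the bandwidth/flops/mem_limit numbers are assumptions rather than values taken from the diff.

import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_pofo

model = tm.resnet18()
data = torch.rand(8, 3, 224, 224, device="meta")              # meta tensor, shapes only
graph = ColoTracer().trace(model, meta_args={"x": data})      # assumed tracer API, as used in the fx tests
gm = ColoGraphModule(model, graph, model.__class__.__name__)

gm = solver_pofo(gm,
                 data,
                 bandwidth=12 * 1024**3,    # assumed offload bandwidth, Byte/s
                 flops=30 * 1e12,           # assumed device throughput, FLOPs/s
                 mem_limit=2 * 1024**3)     # assumed 2 GiB activation budget, Byte
print(getattr(gm, "__sequence__"))          # the pofo sequence attached by the solver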
@@ -4,7 +4,7 @@ from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
import math
from .linearize import linearize
-from .utils import *
+from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
@ -110,10 +110,6 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
|||
return sequence
|
||||
|
||||
|
||||
def _discretize(mem_unit, values):
|
||||
return [math.ceil(value / mem_unit) for value in values]
|
||||
|
||||
|
||||
def _fwd_xbar(node: List[Node]) -> int:
|
||||
"""Get the forward xbar of a node
|
||||
|
||||
|
@@ -204,7 +200,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
    return bwd_mem_tmp


-def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], input) -> Chain:

    fwd_time = []
    bwd_time = []
@@ -226,11 +222,6 @@ def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain
    # currently we view loss backward temp as zero
    tmp_bwd.append(0)

-    xbar_sizes = _discretize(mem_unit, xbar_sizes)
-    x_sizes = _discretize(mem_unit, x_sizes)
-    tmp_fwd = _discretize(mem_unit, tmp_fwd)
-    tmp_bwd = _discretize(mem_unit, tmp_bwd)
-
    return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
@@ -345,7 +336,9 @@ def solver_rotor(gm: ColoGraphModule,
    mem_limit -= parameter_size(gm)
    mem_unit = mem_limit * (1.0 - eps) // mem_slots
    MetaInfoProp(gm).run(data)
-    chain: Chain = _construct_chain(node_list, data, mem_unit)
+
+    chain: Chain = _construct_chain(node_list, data)
+    chain._discretize(mem_unit)
    opt_table = _compute_table(chain, mem_slots)
    sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
    _annotate_from_sequence(sequence, node_list)
@@ -1,3 +1,10 @@
+import math
+
+
+def _discretize(mem_unit, values):
+    return [math.ceil(value / mem_unit) for value in values]
+
+
class Chain:

    def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
@@ -25,6 +32,12 @@ class Chain:
        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
        return chain_list.__repr__()

+    def _discretize(self, mem_unit):
+        self.cweight = _discretize(mem_unit, self.cweight)
+        self.cbweight = _discretize(mem_unit, self.cbweight)
+        self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp)
+        self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp)
+

class Operation:
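A quick worked example of the discretization added above, with assumed numbers and mem_unit computed the same way solver_rotor does:

import math

mem_unit = (1024**3) * (1.0 - 0.0) // 10      # assumed 1 GiB budget, eps = 0.0, 10 slots -> 107374182.0
print(math.ceil(250 * 1024**2 / mem_unit))    # a 250 MiB activation maps to 3 slots, as _discretize would compute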
@@ -35,6 +48,32 @@ class Operation:
        self.index += value


+class Offload(Operation):
+
+    def __init__(self, index, has_bar=False) -> None:
+        super().__init__()
+        self.index = index
+        self.name = "Off"
+        if has_bar:
+            self.name += "wBar"
+
+    def __repr__(self):
+        return f"{self.name}_{self.index}"
+
+
+class Prefetch(Operation):
+
+    def __init__(self, index, has_bar=False) -> None:
+        super().__init__()
+        self.index = index
+        self.name = "Pre"
+        if has_bar:
+            self.name += "wBar"
+
+    def __repr__(self):
+        return f"{self.name}_{self.index}"
+
+
class Forward(Operation):

    def __init__(self, index):
@@ -3,7 +3,7 @@ import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
-from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd
+from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest

try: