mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix C version rotor inconsistency (#1691)
parent
363fc2861a
commit
31d2f03d27
|
@ -10,6 +10,9 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
|
|||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
# global vairable to indicate whether the solver is failed
|
||||
SOLVER_FAILED = False
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
# https://gitlab.inria.fr/hiepacs/rotor
|
||||
|
@ -87,9 +90,17 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
|||
opt, what = opt_table
|
||||
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
|
||||
if opt[cmem][lmin][lmax] == float("inf"):
|
||||
raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
|
||||
lmax=lmax,
|
||||
cmem=cmem))
|
||||
# using logger to annonce that the solver is failed
|
||||
logger = get_dist_logger()
|
||||
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
|
||||
lmax=lmax,
|
||||
cmem=cmem))
|
||||
|
||||
# set global indicater SOLVER_FAILED to True
|
||||
global SOLVER_FAILED
|
||||
SOLVER_FAILED = True
|
||||
return sequence
|
||||
|
||||
if lmin == lmax:
|
||||
if lmin == chain.length:
|
||||
sequence.insert(Loss())
|
||||
|
@ -406,9 +417,18 @@ def solver_rotor(gm: ColoGraphModule,
|
|||
|
||||
# found sequence
|
||||
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
||||
_annotate_from_sequence(sequence, node_list)
|
||||
|
||||
# if solver failed, we don't need to annotate the graph
|
||||
if not SOLVER_FAILED:
|
||||
_annotate_from_sequence(sequence, node_list)
|
||||
|
||||
# set __sequence__ attribute to GraphModule
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
if SOLVER_FAILED:
|
||||
setattr(gm, "__sequence__", None)
|
||||
else:
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
|
||||
# set __opttable__ attribute to GraphModule
|
||||
setattr(gm, "__opttable__", opt_table[0])
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
|
|
@ -94,13 +94,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
|||
OPT(m, i, i) = INFINITY;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
long maxCostFWD = 0;
|
||||
for (long l = i + 1; l <= chain_length; ++l) {
|
||||
long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i];
|
||||
if (l > i + 1) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]);
|
||||
mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD);
|
||||
for (long d = 1; d <= chain_length; ++d) {
|
||||
for (long i = 0; i <= chain_length - d; ++i) {
|
||||
long idx = i + d;
|
||||
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
|
||||
if (idx > i + 1) {
|
||||
long maxCostFWD = 0;
|
||||
for (long j = i + 1; j < idx; j++) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
|
||||
}
|
||||
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
|
||||
}
|
||||
if ((m >= mmin)) {
|
||||
long bestLeaf = -1;
|
||||
|
@ -108,10 +111,10 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
|||
double bestLeafCost = INFINITY;
|
||||
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
|
||||
/// i+1
|
||||
for (long j = i + 1; j <= l; ++j) {
|
||||
for (long j = i + 1; j <= idx; ++j) {
|
||||
sumFw += fw[j - 1];
|
||||
if (m >= cw[j]) {
|
||||
double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1);
|
||||
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
|
||||
if (cost < bestLeafCost) {
|
||||
bestLeafCost = cost;
|
||||
bestLeaf = j;
|
||||
|
@ -120,16 +123,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
|||
}
|
||||
double chainCost = INFINITY;
|
||||
if (m >= cbw[i + 1])
|
||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l);
|
||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
|
||||
if (bestLeafCost <= chainCost) {
|
||||
OPT(m, i, l) = bestLeafCost;
|
||||
WHAT(m, i, l) = bestLeaf;
|
||||
OPT(m, i, idx) = bestLeafCost;
|
||||
WHAT(m, i, idx) = bestLeaf;
|
||||
} else {
|
||||
OPT(m, i, l) = chainCost;
|
||||
WHAT(m, i, l) = -1;
|
||||
OPT(m, i, idx) = chainCost;
|
||||
WHAT(m, i, idx) = -1;
|
||||
}
|
||||
} else
|
||||
OPT(m, i, l) = INFINITY;
|
||||
OPT(m, i, idx) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ except:
|
|||
def _run_C_solver_consistency_test(rank=0):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
|
||||
for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
|
||||
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
|
||||
model = M()
|
||||
data = torch.rand(128, 3, 224, 224, device='meta')
|
||||
|
||||
|
@ -41,15 +41,24 @@ def _run_C_solver_consistency_test(rank=0):
|
|||
# python solver
|
||||
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
|
||||
sequence_python: Sequence = copy.deepcopy(gm.__sequence__)
|
||||
opt_python = copy.deepcopy(gm.__opttable__)
|
||||
|
||||
# C solver
|
||||
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
|
||||
sequence_C: Sequence = copy.deepcopy(gm.__sequence__)
|
||||
opt_C = copy.deepcopy(gm.__opttable__)
|
||||
|
||||
# make sure the opt_tables are the same
|
||||
for m in range(len(opt_python)):
|
||||
for d in range(1, len(opt_python[0])):
|
||||
for i in range(len(opt_python[0]) - d):
|
||||
assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \
|
||||
f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}"
|
||||
|
||||
sequence_python = sequence_python.list_operations()
|
||||
sequence_C = sequence_C.list_operations()
|
||||
|
||||
# make sure the solutions are the same
|
||||
# make sure the sequences are the same
|
||||
assert len(sequence_python) == len(sequence_C) and \
|
||||
all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C))
|
||||
|
||||
|
|
Loading…
Reference in New Issue