From 31d2f03d27377c7981c8df03c06390d511d6cb78 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Wed, 12 Oct 2022 15:21:58 +0800 Subject: [PATCH] [autoparallel] fix C version rotor inconsistency (#1691) --- .../fx/passes/algorithms/ckpt_solver_rotor.py | 30 ++++++++++++++--- .../fx/passes/algorithms/dynamic_programs.c | 33 ++++++++++--------- .../test_C_solver_consistency.py | 13 ++++++-- 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index f5d7dad27..01c3bdb35 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -10,6 +10,9 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions from colossalai.logging import get_dist_logger +# global variable to indicate whether the solver has failed +SOLVER_FAILED = False + # this is the python compute table code from rotor # https://gitlab.inria.fr/hiepacs/rotor @@ -87,9 +90,17 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table): opt, what = opt_table sequence = Sequence(Function("Persistent", lmax - lmin, cmem)) if opt[cmem][lmin][lmax] == float("inf"): - raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin, - lmax=lmax, - cmem=cmem)) + # using logger to announce that the solver has failed + logger = get_dist_logger() + logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin, + lmax=lmax, + cmem=cmem)) + + # set global indicator SOLVER_FAILED to True + global SOLVER_FAILED + SOLVER_FAILED = True + return sequence + if lmin == lmax: if lmin == chain.length: sequence.insert(Loss()) @@ -406,9 +417,18 @@ def solver_rotor(gm: ColoGraphModule, # found sequence sequence = 
_rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) - _annotate_from_sequence(sequence, node_list) + + # if solver failed, we don't need to annotate the graph + if not SOLVER_FAILED: + _annotate_from_sequence(sequence, node_list) # set __sequence__ attribute to GraphModule - setattr(gm, "__sequence__", sequence) + if SOLVER_FAILED: + setattr(gm, "__sequence__", None) + else: + setattr(gm, "__sequence__", sequence) + + # set __opttable__ attribute to GraphModule + setattr(gm, "__opttable__", opt_table[0]) gm.recompile() return gm diff --git a/colossalai/fx/passes/algorithms/dynamic_programs.c b/colossalai/fx/passes/algorithms/dynamic_programs.c index 4bea393a7..3efad5840 100644 --- a/colossalai/fx/passes/algorithms/dynamic_programs.c +++ b/colossalai/fx/passes/algorithms/dynamic_programs.c @@ -94,13 +94,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) { OPT(m, i, i) = INFINITY; for (long m = 0; m <= mmax; ++m) - for (long i = 0; i <= chain_length; ++i) { - long maxCostFWD = 0; - for (long l = i + 1; l <= chain_length; ++l) { - long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i]; - if (l > i + 1) { - maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]); - mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD); + for (long d = 1; d <= chain_length; ++d) { + for (long i = 0; i <= chain_length - d; ++i) { + long idx = i + d; + long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]; + if (idx > i + 1) { + long maxCostFWD = 0; + for (long j = i + 1; j < idx; j++) { + maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]); + } + mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD); } if ((m >= mmin)) { long bestLeaf = -1; @@ -108,10 +111,10 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) { double bestLeafCost = INFINITY; /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j = /// i+1 - for (long j = i + 1; j <= l; ++j) { + for (long j = i + 1; j <= idx; ++j) { sumFw += fw[j - 1]; if (m >= 
cw[j]) { - double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1); + double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1); if (cost < bestLeafCost) { bestLeafCost = cost; bestLeaf = j; @@ -120,16 +123,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) { } double chainCost = INFINITY; if (m >= cbw[i + 1]) - chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l); + chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx); if (bestLeafCost <= chainCost) { - OPT(m, i, l) = bestLeafCost; - WHAT(m, i, l) = bestLeaf; + OPT(m, i, idx) = bestLeafCost; + WHAT(m, i, idx) = bestLeaf; } else { - OPT(m, i, l) = chainCost; - WHAT(m, i, l) = -1; + OPT(m, i, idx) = chainCost; + WHAT(m, i, idx) = -1; } } else - OPT(m, i, l) = INFINITY; + OPT(m, i, idx) = INFINITY; } } diff --git a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py index c638e7ac2..41ed6fd8c 100644 --- a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py @@ -26,7 +26,7 @@ except: def _run_C_solver_consistency_test(rank=0): colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') - for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]: + for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() data = torch.rand(128, 3, 224, 224, device='meta') @@ -41,15 +41,24 @@ def _run_C_solver_consistency_test(rank=0): # python solver gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True) sequence_python: Sequence = copy.deepcopy(gm.__sequence__) + opt_python = copy.deepcopy(gm.__opttable__) # C solver gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024) sequence_C: Sequence = copy.deepcopy(gm.__sequence__) + opt_C = copy.deepcopy(gm.__opttable__) + + # make sure the opt_tables are the same + for m in range(len(opt_python)): 
+ for d in range(1, len(opt_python[0])): + for i in range(len(opt_python[0]) - d): + assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \ + f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" sequence_python = sequence_python.list_operations() sequence_C = sequence_C.list_operations() - # make sure the solutions are the same + # make sure the sequences are the same assert len(sequence_python) == len(sequence_C) and \ all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C))