From ac3739930d36580d86ed3a04445a11d6910951c0 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Mon, 2 Jan 2023 16:26:12 +0800 Subject: [PATCH] [autoparallel] modify construct chain in rotor solver (#2254) --- .../auto_parallel/checkpoint/ckpt_solver_rotor.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 72bc67e02..6ef53c9d1 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -128,16 +128,18 @@ class CheckpointSolverRotor(CheckpointSolverBase): xbar = 0 ftime = 0 btime = 0 + fwd_mem_peak = 0 for n in node: assert isinstance(n, Node), f'{n} is not a Node' xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) # minimum flop count is required ftime += max(calculate_fwd_time(n), 1.0) btime += max(calculate_bwd_time(n), 1.0) x = calculate_fwd_out(node[-1]) xbar = max(x, xbar) - ftmp = cls._extract_ftmp(node) + ftmp = fwd_mem_peak - xbar btmp = cls._extract_btmp(node) return ftime, btime, x, xbar, ftmp, btmp @@ -151,10 +153,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): return input_tensors @staticmethod - def _extract_ftmp(node: List[Node]) -> int: - """Extract ftmp from a list of nodes""" - n = node[-1] - return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n) + def _extract_unused_output(node: Node) -> int: + """Extract unused output from `torch.fx.Node`""" + return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) @staticmethod def _extract_btmp(node: List[Node]) -> int: