Browse Source

[autoparallel] modify construct chain in rotor solver (#2254)

pull/2257/head
Boyuan Yao 2 years ago committed by GitHub
parent
commit
ac3739930d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 11
      colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py

11
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py

@ -128,16 +128,18 @@ class CheckpointSolverRotor(CheckpointSolverBase):
xbar = 0
ftime = 0
btime = 0
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
ftmp = fwd_mem_peak - xbar
btmp = cls._extract_btmp(node)
return ftime, btime, x, xbar, ftmp, btmp
@ -151,10 +153,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return input_tensors
@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
    """Compute ftmp from the final node of a node list.

    Only the last element of ``node`` is inspected. The result is
    ``activation_size`` of that node's ``meta['fwd_out']`` minus the value
    ``calculate_fwd_out`` reports for the same node.
    """
    last = node[-1]
    out_size = activation_size(last.meta['fwd_out'])
    out_calc = calculate_fwd_out(last)
    return out_size - out_calc
def _extract_unused_output(node: Node) -> int:
    """Compute the unused-output size of a single `torch.fx.Node`.

    The result is ``activation_size`` of ``node.meta['fwd_out']`` minus the
    value ``calculate_fwd_out`` reports for the same node.
    """
    out_size = activation_size(node.meta['fwd_out'])
    out_calc = calculate_fwd_out(node)
    return out_size - out_calc
@staticmethod
def _extract_btmp(node: List[Node]) -> int:

Loading…
Cancel
Save