mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] modify construct chain in rotor solver (#2254)
parent
ab38aebace
commit
ac3739930d
|
@ -128,16 +128,18 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
xbar = 0
|
||||
ftime = 0
|
||||
btime = 0
|
||||
fwd_mem_peak = 0
|
||||
for n in node:
|
||||
assert isinstance(n, Node), f'{n} is not a Node'
|
||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
|
||||
# minimum flop count is required
|
||||
ftime += max(calculate_fwd_time(n), 1.0)
|
||||
btime += max(calculate_bwd_time(n), 1.0)
|
||||
|
||||
x = calculate_fwd_out(node[-1])
|
||||
xbar = max(x, xbar)
|
||||
ftmp = cls._extract_ftmp(node)
|
||||
ftmp = fwd_mem_peak - xbar
|
||||
btmp = cls._extract_btmp(node)
|
||||
return ftime, btime, x, xbar, ftmp, btmp
|
||||
|
||||
|
@ -151,10 +153,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
return input_tensors
|
||||
|
||||
@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
    """Compute ftmp for a list of nodes, using only its last element.

    The result is the activation size of the last node's ``meta['fwd_out']``
    minus the value ``calculate_fwd_out`` returns for that same node.
    """
    last = node[-1]
    out_size = activation_size(last.meta['fwd_out'])
    fwd_out = calculate_fwd_out(last)
    return out_size - fwd_out
|
||||
def _extract_unused_output(node: Node) -> int:
    """Compute the unused output size of a single `torch.fx.Node`.

    The result is the activation size of ``node.meta['fwd_out']`` minus the
    value ``calculate_fwd_out`` returns for ``node``.
    """
    out_size = activation_size(node.meta['fwd_out'])
    fwd_out = calculate_fwd_out(node)
    return out_size - fwd_out
|
||||
|
||||
@staticmethod
|
||||
def _extract_btmp(node: List[Node]) -> int:
|
||||
|
|
Loading…
Reference in New Issue