mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] modify construct chain in rotor solver (#2254)
parent
ab38aebace
commit
ac3739930d
|
@ -128,16 +128,18 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||||
xbar = 0
|
xbar = 0
|
||||||
ftime = 0
|
ftime = 0
|
||||||
btime = 0
|
btime = 0
|
||||||
|
fwd_mem_peak = 0
|
||||||
for n in node:
|
for n in node:
|
||||||
assert isinstance(n, Node), f'{n} is not a Node'
|
assert isinstance(n, Node), f'{n} is not a Node'
|
||||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
||||||
|
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
|
||||||
# minimum flop count is required
|
# minimum flop count is required
|
||||||
ftime += max(calculate_fwd_time(n), 1.0)
|
ftime += max(calculate_fwd_time(n), 1.0)
|
||||||
btime += max(calculate_bwd_time(n), 1.0)
|
btime += max(calculate_bwd_time(n), 1.0)
|
||||||
|
|
||||||
x = calculate_fwd_out(node[-1])
|
x = calculate_fwd_out(node[-1])
|
||||||
xbar = max(x, xbar)
|
xbar = max(x, xbar)
|
||||||
ftmp = cls._extract_ftmp(node)
|
ftmp = fwd_mem_peak - xbar
|
||||||
btmp = cls._extract_btmp(node)
|
btmp = cls._extract_btmp(node)
|
||||||
return ftime, btime, x, xbar, ftmp, btmp
|
return ftime, btime, x, xbar, ftmp, btmp
|
||||||
|
|
||||||
|
@ -151,10 +153,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||||
return input_tensors
|
return input_tensors
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_ftmp(node: List[Node]) -> int:
|
def _extract_unused_output(node: Node) -> int:
|
||||||
"""Extract ftmp from a list of nodes"""
|
"""Extract unused output from `torch.fx.Node`"""
|
||||||
n = node[-1]
|
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
|
||||||
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_btmp(node: List[Node]) -> int:
|
def _extract_btmp(node: List[Node]) -> int:
|
||||||
|
|
Loading…
Reference in New Issue