|
|
|
@ -128,16 +128,18 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|
|
|
|
xbar = 0 |
|
|
|
|
ftime = 0 |
|
|
|
|
btime = 0 |
|
|
|
|
fwd_mem_peak = 0 |
|
|
|
|
for n in node: |
|
|
|
|
assert isinstance(n, Node), f'{n} is not a Node' |
|
|
|
|
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) |
|
|
|
|
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) |
|
|
|
|
# minimum flop count is required |
|
|
|
|
ftime += max(calculate_fwd_time(n), 1.0) |
|
|
|
|
btime += max(calculate_bwd_time(n), 1.0) |
|
|
|
|
|
|
|
|
|
x = calculate_fwd_out(node[-1]) |
|
|
|
|
xbar = max(x, xbar) |
|
|
|
|
ftmp = cls._extract_ftmp(node) |
|
|
|
|
ftmp = fwd_mem_peak - xbar |
|
|
|
|
btmp = cls._extract_btmp(node) |
|
|
|
|
return ftime, btime, x, xbar, ftmp, btmp |
|
|
|
|
|
|
|
|
@ -151,10 +153,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|
|
|
|
return input_tensors |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _extract_ftmp(node: List[Node]) -> int: |
|
|
|
|
"""Extract ftmp from a list of nodes""" |
|
|
|
|
n = node[-1] |
|
|
|
|
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n) |
|
|
|
|
def _extract_unused_output(node: Node) -> int: |
|
|
|
|
"""Extract unused output from `torch.fx.Node`""" |
|
|
|
|
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _extract_btmp(node: List[Node]) -> int: |
|
|
|
|