@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
 from torch import Tensor
 from torch.fx import Graph, Node
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.profiler import (
     activation_size,
@@ -131,8 +132,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         fwd_mem_peak = 0
         for n in node:
             assert isinstance(n, Node), f'{n} is not a Node'
-            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
-            fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
+                # in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
+                xbar += n.meta['fwd_mem_out']
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+            else:
+                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+
             # minimum flop count is required
             ftime += max(calculate_fwd_time(n), 1.0)
             btime += max(calculate_bwd_time(n), 1.0)
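
Note (not part of the patch): as a reading aid, here is a minimal, self-contained sketch of the accounting the new branch performs. It assumes toy dictionaries in place of `fx.Node.meta` and the profiler helpers `calculate_fwd_tmp` / `calculate_fwd_out`; the names `FakeNode`, `accumulate_fwd_memory`, `is_runtime_apply`, `fwd_tmp`, `fwd_out`, and `unused_out` are illustrative stand-ins, not ColossalAI APIs.

from typing import Dict, List, Tuple

# FakeNode is a hypothetical stand-in for torch.fx.Node: just the profiler
# metadata plus a flag marking runtime_apply / runtime_comm_spec_apply targets.
FakeNode = Dict[str, int]

def accumulate_fwd_memory(nodes: List[FakeNode]) -> Tuple[int, int]:
    """Mirror the diff's loop body: `xbar` accumulates the memory that stays
    live after each node's forward; `fwd_mem_peak` tracks the high-water mark
    (live memory plus the node's temporary forward memory)."""
    xbar, fwd_mem_peak = 0, 0
    for n in nodes:
        if n["is_runtime_apply"]:
            # these nodes carry their memory statistics directly in node.meta
            xbar += n["fwd_mem_out"]
            fwd_mem_peak = max(fwd_mem_peak, xbar + n["fwd_mem_tmp"])
        else:
            # ordinary nodes: saved activations plus outputs stay live,
            # and unused outputs still contribute to the peak
            xbar += n["fwd_tmp"] + n["fwd_out"]
            fwd_mem_peak = max(fwd_mem_peak, xbar + n["fwd_mem_tmp"] + n["unused_out"])
    return xbar, fwd_mem_peak

# one runtime_apply-style node followed by one ordinary node
nodes = [
    {"is_runtime_apply": True, "fwd_mem_out": 4, "fwd_mem_tmp": 2},
    {"is_runtime_apply": False, "fwd_tmp": 1, "fwd_out": 3, "fwd_mem_tmp": 2, "unused_out": 0},
]
print(accumulate_fwd_memory(nodes))    # (8, 10)

The point of the special case, per the comment in the diff: the generic profiler helpers do not apply to runtime_apply / runtime_comm_spec_apply nodes, so the solver reads the statistics that the runtime pass has already hooked into node.meta.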