diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index f7de4987c..2f2727215 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -142,8 +142,6 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i bwd_time.append(0) - fwd_time = _discretize(mem_unit, fwd_time) - bwd_time = _discretize(mem_unit, bwd_time) xbar_sizes = _discretize(mem_unit, xbar_sizes) x_sizes = _discretize(mem_unit, x_sizes) tmp_fwd = _discretize(mem_unit, tmp_fwd)