diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index 2f2727215..ce209b674 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,6 +1,7 @@
 from typing import List, Set, Tuple, Dict
 import torch
 from torch.fx import GraphModule, Node
+from colossalai.fx.graph_module import ColoGraphModule
 import math
 from .linearize import linearize
 from .utils import *
@@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
         x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
                        torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
         for node in node_dict[key]:
-            fwd_time[-1] += node.__flops__
+            fwd_time[-1] += max(node.__flops__, 1)

             # currently we haven't patched the backward flops count
-            bwd_time[-1] += node.__flops__ * 2
+            bwd_time[-1] += max(node.__flops__ * 2, 2)
             xbar_sizes[-1] += node.__activation__
@@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
         elif isinstance(op, ForwardEnable):
             in_ckpt = False
-            for idx in ckpt_region:
-                for node in node_dict[idx]:
+            for node_idx in ckpt_region:
+                for node in node_dict[node_idx]:
                     setattr(node, "activation_checkpoint", ckpt_idx)

             ckpt_idx += 1
             ckpt_region = []

         elif isinstance(op, ForwardCheck):
-            for idx in ckpt_region:
-                for node in node_dict[idx]:
+            for node_idx in ckpt_region:
+                for node in node_dict[node_idx]:
                     setattr(node, "activation_checkpoint", ckpt_idx)

             ckpt_idx += 1
@@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
             ckpt_region.append(idx)


-def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
+def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
+    """Solver that automatically finds activation checkpoints in rotor's manner.
+
+    Args:
+        gm (ColoGraphModule): ColoGraphModule generated by tracing the model.
+        data (torch.Tensor): Input data.
+        mem_limit (int): Memory budget in bytes.
+        mem_slots (int, optional): Number of slots for discretizing the memory budget. Defaults to 500.
+
+    Returns:
+        ColoGraphModule: The annotated ColoGraphModule with a __sequence__ attribute.
+    """
+
     node_dict = linearize(gm)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
@@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots:
     opt_table = _compute_table(chain, mem_slots)
     sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
     _annotate_from_sequence(sequence, node_dict)
+
+    # set the __sequence__ attribute on the GraphModule
+    setattr(gm, "__sequence__", sequence)
     return gm
diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
index 1d6352d07..ea9aec43d 100644
--- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
+++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
 def _run_ckpt_solver(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

-    MODEL_LIST = [tm.resnet18, tm.densenet121]
+    MODEL_LIST = [tm.densenet121]

     torch.backends.cudnn.deterministic = True

     tracer = ColoTracer(trace_act_ckpt=False)

-    data = torch.rand(2, 3, 32, 32, device='meta')
+    data = torch.rand(8, 3, 224, 224, device='meta')
     for solver in SOLVERS:
         for model_cls in MODEL_LIST:
             m = model_cls(num_classes=5)
@@ -95,13 +95,13 @@ def test_ckpt_solver():
 def _run_ckpt_solver_torch11(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

-    MODEL_LIST = [tm.resnet18, tm.densenet121]
+    MODEL_LIST = [tm.densenet121]

     torch.backends.cudnn.deterministic = True

     tracer = ColoTracer(trace_act_ckpt=False)

-    data = torch.rand(2, 3, 32, 32, device='meta')
+    data = torch.rand(8, 3, 32, 32, device='meta')
     for solver in SOLVERS:
         for model_cls in MODEL_LIST:
             m = model_cls(num_classes=5)
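
For context, a minimal usage sketch of the `solver_rotor` entry point changed above, following the call pattern in `test_ckpt_torchvision.py`. The import paths mirror the file layout in this diff but may differ across versions, and the model choice and 2 GiB budget are illustrative assumptions, not values from this change; `MetaInfoProp` is run inside `solver_rotor`, so no separate propagation pass is needed.

```python
# Minimal usage sketch (assumed imports and memory budget; adapted from
# the call pattern in test_ckpt_torchvision.py).
import torch
import torchvision.models as tm

from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor

model = tm.densenet121(num_classes=5)
data = torch.rand(8, 3, 224, 224, device='meta')

# Trace on meta tensors so shape propagation allocates no real memory.
tracer = ColoTracer(trace_act_ckpt=False)
graph = tracer.trace(model, meta_args={'x': data})
gm = ColoGraphModule(model, graph, model.__class__.__name__)

# Annotate activation checkpoints under a 2 GiB budget (illustrative value);
# solver_rotor runs MetaInfoProp internally before building the chain.
gm = solver_rotor(gm, data=data, mem_limit=2 * 1024**3)

# The chosen schedule is now attached to the module as __sequence__.
print(gm.__sequence__)
```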