import pytest
import torch
import torchvision.models as tm

from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except ImportError:
    # fall back to older pytorch version: only the python_code-based
    # activation checkpoint hook is exported there.
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False

# Models to trace, each with the memory budgets (in MiB, see `mem_limit`
# below) the rotor solver is run against.
_MODEL_BUDGETS_MB = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}


def _forward_ops_before_loss(gm, budget):
    """Run the rotor solver on ``gm`` and return ``(gm, forward_ops)``.

    ``forward_ops`` is the solver's operation sequence truncated just before
    the first ``Loss`` operation, i.e. only the forward part of the schedule.
    """
    gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
    op_list = gm.__sequence__.list_operations()
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    return gm, op_list[:op_list.index(loss_op)]


def _assert_annotated(nodes, ckpt_idx):
    """Assert every node in ``nodes`` carries ``activation_checkpoint == ckpt_idx``."""
    for n in nodes:
        assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
        assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"


def _check_ckpt_annotations(op_list, node_list):
    """Verify the checkpoint annotations written by the solver onto the graph nodes.

    Walks the forward operations in order, tracking whether we are inside a
    checkpoint region and which region index is current. Assumes
    ``node_list[idx]`` is the group of nodes corresponding to ``op_list[idx]``
    (linearize's groups line up with the solver's forward ops -- TODO confirm).
    """
    in_ckpt = False
    ckpt_idx = 0
    for idx, op in enumerate(op_list):
        nodes = node_list[idx]
        if not in_ckpt:
            # Outside a region only a ForwardCheck (which opens one) is checked.
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                _assert_annotated(nodes, ckpt_idx)
            continue

        if isinstance(op, ForwardNograd):
            # Still inside the current region.
            _assert_annotated(nodes, ckpt_idx)
        elif isinstance(op, ForwardEnable):
            # ForwardEnable closes the region; its nodes must not be annotated.
            for n in nodes:
                assert getattr(n, "activation_checkpoint", None) is None, f"{n} should not be annotated!"
            in_ckpt = False
            ckpt_idx += 1
        elif isinstance(op, ForwardCheck):
            # A new region starts immediately after the previous one.
            ckpt_idx += 1
            _assert_annotated(nodes, ckpt_idx)


@pytest.mark.skip(reason='TODO: modify calculations in rotor')
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_linearize():
    """Check rotor-solver checkpoint annotations using ``ActivationCheckpointCodeGen``."""
    tracer = ColoTracer()
    for M, budgets in _MODEL_BUDGETS_MB.items():
        for budget in budgets:
            model = M()
            graph = tracer.trace(model)
            graph.set_codegen(ActivationCheckpointCodeGen())
            gm = ColoGraphModule(model, graph, model.__class__.__name__)
            node_list = linearize(gm)
            gm, op_list = _forward_ops_before_loss(gm, budget)
            _check_ckpt_annotations(op_list, node_list)
            del model
            del gm
            del node_list


@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
def test_linearize_torch11():
    """Same as ``test_linearize`` but for torch < 1.12, patching ``_python_code`` instead of using a CodeGen."""
    tracer = ColoTracer()
    for M, budgets in _MODEL_BUDGETS_MB.items():
        for budget in budgets:
            model = M()
            graph = tracer.trace(model)
            gm = ColoGraphModule(model, graph, model.__class__.__name__)
            # Bind the checkpoint-aware code generator as a method of this graph.
            gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
            node_list = linearize(gm)
            gm, op_list = _forward_ops_before_loss(gm, budget)
            _check_ckpt_annotations(op_list, node_list)
            del model
            del gm
            del node_list


if __name__ == "__main__":
    test_linearize()