diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index 2d6a71f19..ebb9975f2 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -22,8 +22,13 @@ if COLOGM: class ColoGraphModule(GraphModule): - def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): - graph.set_codegen(ActivationCheckpointCodeGen()) + def __init__(self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = 'GraphModule', + ckpt_codegen: bool = True): + if ckpt_codegen: + graph.set_codegen(ActivationCheckpointCodeGen()) super().__init__(root, graph, class_name) def bind(self, ckpt_def, globals):