Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into dev0116

pull/2499/head
jiaruifang 2023-01-18 18:43:11 +08:00
commit 7f822a5c45
1 changed files with 7 additions and 2 deletions

View File

@ -22,8 +22,13 @@ if COLOGM:
class ColoGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
graph.set_codegen(ActivationCheckpointCodeGen())
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
class_name: str = 'GraphModule',
ckpt_codegen: bool = True):
if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):