mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into dev0116
commit
7f822a5c45
|
@ -22,8 +22,13 @@ if COLOGM:
|
|||
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
def __init__(self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: Graph,
|
||||
class_name: str = 'GraphModule',
|
||||
ckpt_codegen: bool = True):
|
||||
if ckpt_codegen:
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
|
|
Loading…
Reference in New Issue