mirror of https://github.com/hpcaitech/ColossalAI
[fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be used. I add an arg to control whether ActivationCheckpointCodeGen is set in __init__. * code style
pull/2499/head
parent
e327e95144
commit
5db3a5bf42
|
@@ -22,8 +22,13 @@ if COLOGM:
|
||||||
|
|
||||||
class ColoGraphModule(GraphModule):
|
class ColoGraphModule(GraphModule):
|
||||||
|
|
||||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
def __init__(self,
|
||||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||||
|
graph: Graph,
|
||||||
|
class_name: str = 'GraphModule',
|
||||||
|
ckpt_codegen: bool = True):
|
||||||
|
if ckpt_codegen:
|
||||||
|
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||||
super().__init__(root, graph, class_name)
|
super().__init__(root, graph, class_name)
|
||||||
|
|
||||||
def bind(self, ckpt_def, globals):
|
def bind(self, ckpt_def, globals):
|
||||||
|
|
Loading…
Reference in New Issue