mirror of https://github.com/hpcaitech/ColossalAI
[fx] Fix ckpt functions' definitions in forward (#1476)
* [fx] fix defining ckpt functions inside forward * [fx] Modify activation checkpoint codegen and add ColoGraphModule * [fx] some modification * some modifications * some modifications * some modifications * some modifications * some code modificationspull/1477/head
parent
bb5f5289e0
commit
1f2e547f7a
|
@ -1,12 +1,13 @@
|
|||
import colossalai
|
||||
import torch
|
||||
from typing import List, Callable, Any, Tuple, Dict
|
||||
|
||||
try:
|
||||
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
|
||||
CODEGEN_AVAILABLE = True
|
||||
except:
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
|
||||
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
|
||||
CODEGEN_AVAILABLE = False
|
||||
|
||||
|
@ -89,7 +90,7 @@ def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
|||
"""
|
||||
Generate the checkpoint function definition
|
||||
"""
|
||||
return f"def checkpoint_{label}({', '.join(free_vars)}):"
|
||||
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
|
||||
|
||||
|
||||
def _gen_ckpt_output(output_vars: List[str]) -> str:
|
||||
|
@ -105,10 +106,10 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
|
|||
"""
|
||||
outputs = ', '.join(output_vars)
|
||||
inputs = ', '.join(input_vars)
|
||||
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
|
||||
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
|
||||
|
||||
|
||||
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
|
||||
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
|
||||
# find the activation checkpoint regions
|
||||
ckpt_regions = _find_ckpt_regions(nodes)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
|
@ -133,27 +134,27 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
|
|||
if idx in start_idx:
|
||||
label = start_idx.index(idx)
|
||||
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
|
||||
body.append(f'{ckpt_fn_def}\n')
|
||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
within_ckpt_region = True
|
||||
|
||||
# NOTE: emit_node does not emit a string with newline. It depends
|
||||
# on delete_unused_values to append one
|
||||
emit_node_func(node)
|
||||
|
||||
# add indentation to the emmited node
|
||||
# NOTE: currently we separate body and ckpt_func definition
|
||||
if within_ckpt_region:
|
||||
body[-1] = ' ' + body[-1]
|
||||
|
||||
# delete unused values
|
||||
delete_unused_value_func(node)
|
||||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
|
||||
if idx in end_idx:
|
||||
# if this is the last node of the ckpt region
|
||||
# generate return statement
|
||||
label = end_idx.index(idx)
|
||||
return_statement = _gen_ckpt_output(output_vars[label])
|
||||
return_statement = f' {return_statement}\n'
|
||||
body.append(return_statement)
|
||||
return_statement = f' {return_statement}\n\n'
|
||||
ckpt_func.append(return_statement)
|
||||
|
||||
# we need to check if the checkpoint need to offload the input
|
||||
start_node_idx = start_idx[label]
|
||||
|
@ -221,6 +222,9 @@ if CODEGEN_AVAILABLE:
|
|||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
|
@ -287,7 +291,8 @@ if CODEGEN_AVAILABLE:
|
|||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
def delete_unused_values(user: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
|
@ -305,7 +310,8 @@ if CODEGEN_AVAILABLE:
|
|||
else:
|
||||
body.append('\n')
|
||||
|
||||
def emit_node(node: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
assert isinstance(node.target, str)
|
||||
|
@ -371,7 +377,8 @@ if CODEGEN_AVAILABLE:
|
|||
raise NotImplementedError(f'node: {node.op} {node.target}')
|
||||
|
||||
# Modified for activation checkpointing
|
||||
emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
|
||||
ckpt_func = []
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
|
@ -395,7 +402,8 @@ if CODEGEN_AVAILABLE:
|
|||
# in forward function
|
||||
# TODO: Remove inline import
|
||||
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
|
||||
prologue = prologue + "\n import colossalai"
|
||||
prologue = ''.join(ckpt_func) + prologue
|
||||
prologue = prologue
|
||||
|
||||
code = ''.join(body)
|
||||
code = '\n'.join(' ' + line for line in code.split('\n'))
|
||||
|
@ -444,6 +452,9 @@ else:
|
|||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
|
@ -484,7 +495,8 @@ else:
|
|||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
def delete_unused_values(user: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
|
@ -502,7 +514,8 @@ else:
|
|||
else:
|
||||
body.append('\n')
|
||||
|
||||
def emit_node(node: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
assert isinstance(node.target, str)
|
||||
|
@ -562,7 +575,8 @@ else:
|
|||
raise NotImplementedError(f'node: {node.op} {node.target}')
|
||||
|
||||
# Modified for activation checkpointing
|
||||
emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
|
||||
ckpt_func = []
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
|
@ -587,6 +601,8 @@ else:
|
|||
else:
|
||||
wrap_stmts = ''
|
||||
|
||||
ckpt_func = ''.join(ckpt_func)
|
||||
|
||||
# If the original function didn't have self as its first argument, we
|
||||
# would have added it.
|
||||
if len(orig_args) == 0 or orig_args[0] != 'self':
|
||||
|
@ -600,7 +616,7 @@ else:
|
|||
fn_code = f"""
|
||||
{wrap_stmts}
|
||||
|
||||
{ckpt_func}
|
||||
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
|
||||
import colossalai
|
||||
{code}"""
|
||||
return PythonCode(fn_code, globals_)
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
import os
|
||||
import warnings
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.module import _addindent
|
||||
from typing import Type, Dict, List, Any, Union, Optional, Set
|
||||
from pathlib import Path
|
||||
try:
|
||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
|
||||
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
|
||||
COLOGM = True
|
||||
except:
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.graph import Graph
|
||||
COLOGM = False
|
||||
|
||||
if COLOGM:
|
||||
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
||||
super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
"""Bind checkpoint functions to ColoGraphModule
|
||||
We need to bind our checkpoint functions to the GraphModule so
|
||||
that we could correctly use self.checkpoint for GraphModule forward
|
||||
"""
|
||||
ckpt_code = "\n".join(ckpt_def)
|
||||
globals_copy = globals.copy()
|
||||
_exec_with_source(ckpt_code, globals_copy)
|
||||
func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
|
||||
for func in func_list:
|
||||
tmp_func = globals_copy[func]
|
||||
setattr(self, func, tmp_func.__get__(self, self.__class__))
|
||||
del globals_copy[func]
|
||||
|
||||
def recompile(self) -> PythonCode:
|
||||
"""
|
||||
Recompile this GraphModule from its ``graph`` attribute. This should be
|
||||
called after editing the contained ``graph``, otherwise the generated
|
||||
code of this ``GraphModule`` will be out of date.
|
||||
"""
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
python_code = self._graph.python_code(root_module='self')
|
||||
self._code = python_code.src
|
||||
|
||||
# To split ckpt functions code and forward code
|
||||
_code_list = self._code.split("\n")
|
||||
_fwd_def = [item for item in _code_list if "def forward" in item][0]
|
||||
_fwd_idx = _code_list.index(_fwd_def)
|
||||
ckpt_def = _code_list[:_fwd_idx]
|
||||
self._code = "\n".join(_code_list[_fwd_idx:])
|
||||
|
||||
self.bind(ckpt_def, python_code.globals)
|
||||
|
||||
cls = type(self)
|
||||
cls.forward = _forward_from_src(self._code, python_code.globals)
|
||||
|
||||
# Determine whether this class explicitly defines a __call__ implementation
|
||||
# to wrap. If it does, save it in order to have wrapped_call invoke it.
|
||||
# If it does not, wrapped_call can use a dynamic call to super() instead.
|
||||
# In most cases, super().__call__ should be torch.nn.Module.__call__.
|
||||
# We do not want to hold a reference to Module.__call__ here; doing so will
|
||||
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
|
||||
cls_call = cls.__call__ if "__call__" in vars(cls) else None
|
||||
|
||||
if '_wrapped_call' not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
return self._wrapped_call(self, *args, **kwargs)
|
||||
|
||||
cls.__call__ = call_wrapped
|
||||
|
||||
# reset self._code to original src, otherwise to_folder will be wrong
|
||||
self._code = python_code.src
|
||||
return python_code
|
||||
|
||||
def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
|
||||
"""Dumps out module to ``folder`` with ``module_name`` so that it can be
|
||||
imported with ``from <folder> import <module_name>``
|
||||
|
||||
Args:
|
||||
|
||||
folder (Union[str, os.PathLike]): The folder to write the code out to
|
||||
|
||||
module_name (str): Top-level name to use for the ``Module`` while
|
||||
writing out the code
|
||||
"""
|
||||
folder = Path(folder)
|
||||
Path(folder).mkdir(exist_ok=True)
|
||||
torch.save(self.state_dict(), folder / 'state_dict.pt')
|
||||
tab = " " * 4
|
||||
|
||||
# we add import colossalai here
|
||||
model_str = f"""
|
||||
import torch
|
||||
from torch.nn import *
|
||||
import colossalai
|
||||
|
||||
|
||||
class {module_name}(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
"""
|
||||
|
||||
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
|
||||
safe_reprs = [
|
||||
nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
|
||||
]
|
||||
if type(module) in safe_reprs:
|
||||
return f"{module.__repr__()}"
|
||||
else:
|
||||
return None
|
||||
|
||||
blobified_modules = []
|
||||
for module_name, module in self.named_children():
|
||||
module_str = _gen_model_repr(module_name, module)
|
||||
if module_str is None:
|
||||
module_file = folder / f'{module_name}.pt'
|
||||
torch.save(module, module_file)
|
||||
blobified_modules.append(module_name)
|
||||
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
|
||||
module_str = f"torch.load(r'{module_file}') # {module_repr}"
|
||||
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
|
||||
|
||||
for buffer_name, buffer in self._buffers.items():
|
||||
if buffer is None:
|
||||
continue
|
||||
model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
|
||||
|
||||
for param_name, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
|
||||
|
||||
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
|
||||
model_str += f"{_addindent(self.code, 4)}\n"
|
||||
|
||||
module_file = folder / 'module.py'
|
||||
module_file.write_text(model_str)
|
||||
|
||||
init_file = folder / '__init__.py'
|
||||
init_file.write_text('from .module import *')
|
||||
|
||||
if len(blobified_modules) > 0:
|
||||
warnings.warn("Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}")
|
||||
|
||||
else:
|
||||
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
||||
super().__init__(root, graph, class_name)
|
|
@ -7,6 +7,7 @@ import torchvision.models as tm
|
|||
from torch.fx import GraphModule
|
||||
import colossalai
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.passes.algorithms import chen_greedy
|
||||
from colossalai.utils import free_port
|
||||
|
@ -72,7 +73,7 @@ def _run_ckpt_solver(rank):
|
|||
for model_cls in MODEL_LIST:
|
||||
m = model_cls(num_classes=5)
|
||||
graph = tracer.trace(root=m)
|
||||
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||
MetaInfoProp(gm).run(data)
|
||||
codegen = ActivationCheckpointCodeGen()
|
||||
gm.graph.set_codegen(codegen)
|
||||
|
@ -102,7 +103,7 @@ def _run_ckpt_solver_torch11(rank):
|
|||
for model_cls in MODEL_LIST:
|
||||
m = model_cls(num_classes=5)
|
||||
graph = tracer.trace(root=m)
|
||||
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||
MetaInfoProp(gm).run(data)
|
||||
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
||||
gm = solver(gm)
|
||||
|
@ -114,10 +115,12 @@ def _run_ckpt_solver_torch11(rank):
|
|||
|
||||
|
||||
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
||||
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
|
||||
def test_ckpt_solver_torch11():
|
||||
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ckpt_solver()
|
||||
test_ckpt_solver_torch11()
|
||||
_run_ckpt_solver(rank=0)
|
||||
# test_ckpt_solver()
|
||||
# test_ckpt_solver_torch11()
|
||||
|
|
|
@ -9,6 +9,7 @@ from colossalai.fx import ColoTracer
|
|||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
|
||||
try:
|
||||
from colossalai.fx.codegen import ActivationCheckpointCodeGen
|
||||
|
@ -46,7 +47,7 @@ class MyModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
self.mlp1 = MLP()
|
||||
self.relu = relu()
|
||||
self.linear3 = torch.nn.Linear(4, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
y1, y2 = checkpoint(self.mlp1, x)
|
||||
|
@ -56,6 +57,7 @@ class MyModule(torch.nn.Module):
|
|||
return F.relu(x, inplace=True)
|
||||
|
||||
y4 = checkpoint(ckpt2, x)
|
||||
y4 = self.linear2(y4)
|
||||
return y1 + y2 + y3 + y4
|
||||
|
||||
|
||||
|
@ -91,15 +93,15 @@ def _run_act_ckpt_codegen(rank):
|
|||
if node.name in offload_starts:
|
||||
setattr(node, 'activation_offload', True)
|
||||
|
||||
gm = GraphModule(model, graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# assert checkpoint function will be generated and
|
||||
# the offload option is correct
|
||||
code = graph.python_code('self').src
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
|
||||
# recompile and verify the outputs are consistent
|
||||
fx_out = gm(data)
|
||||
|
@ -145,14 +147,14 @@ def _run_act_ckpt_python_code_torch11(rank):
|
|||
if node.name in offload_starts:
|
||||
setattr(node, 'activation_offload', True)
|
||||
|
||||
gm = GraphModule(model, graph)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
# assert checkpoint function will be generated and
|
||||
# the offload option is correct
|
||||
code = graph.python_code('self').src
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
|
||||
# recompile and verify the outputs are consistent
|
||||
fx_out = gm(data)
|
||||
|
@ -162,11 +164,10 @@ def _run_act_ckpt_python_code_torch11(rank):
|
|||
|
||||
|
||||
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
||||
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
|
||||
def test_act_ckpt_python_code_torch11():
|
||||
mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_act_ckpt_codegen()
|
||||
test_act_ckpt_python_code_torch11()
|
||||
_run_act_ckpt_codegen(rank=0)
|
||||
|
|
Loading…
Reference in New Issue