[fx] Fix ckpt functions' definitions in forward (#1476)

* [fx] fix defining ckpt functions inside forward

* [fx] Modify activation checkpoint codegen and add ColoGraphModule

* [fx] some modification

* some modifications

* some modifications

* some modifications

* some modifications

* some code modifications
pull/1477/head
Boyuan Yao 2022-08-22 16:59:54 +08:00 committed by GitHub
parent bb5f5289e0
commit 1f2e547f7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 217 additions and 39 deletions

View File

@ -1,12 +1,13 @@
import colossalai
import torch
from typing import List, Callable, Any, Tuple, Dict
try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
CODEGEN_AVAILABLE = False
@ -89,7 +90,7 @@ def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
"""
Generate the checkpoint function definition
"""
return f"def checkpoint_{label}({', '.join(free_vars)}):"
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
def _gen_ckpt_output(output_vars: List[str]) -> str:
@ -105,10 +106,10 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
# find the activation checkpoint regions
ckpt_regions = _find_ckpt_regions(nodes)
start_idx = [item[0] for item in ckpt_regions]
@ -133,27 +134,27 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
body.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f'{ckpt_fn_def}\n')
within_ckpt_region = True
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
emit_node_func(node)
# add indentation to the emmited node
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
body[-1] = ' ' + body[-1]
# delete unused values
delete_unused_value_func(node)
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
if idx in end_idx:
# if this is the last node of the ckpt region
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n'
body.append(return_statement)
return_statement = f' {return_statement}\n\n'
ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
@ -221,6 +222,9 @@ if CODEGEN_AVAILABLE:
globals_[global_name] = obj
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
@ -287,7 +291,8 @@ if CODEGEN_AVAILABLE:
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
def delete_unused_values(user: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
@ -305,7 +310,8 @@ if CODEGEN_AVAILABLE:
else:
body.append('\n')
def emit_node(node: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
@ -371,7 +377,8 @@ if CODEGEN_AVAILABLE:
raise NotImplementedError(f'node: {node.op} {node.target}')
# Modified for activation checkpointing
emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -395,7 +402,8 @@ if CODEGEN_AVAILABLE:
# in forward function
# TODO: Remove inline import
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = prologue + "\n import colossalai"
prologue = ''.join(ckpt_func) + prologue
prologue = prologue
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
@ -444,6 +452,9 @@ else:
globals_[global_name] = obj
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
@ -484,7 +495,8 @@ else:
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
def delete_unused_values(user: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
@ -502,7 +514,8 @@ else:
else:
body.append('\n')
def emit_node(node: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
@ -562,7 +575,8 @@ else:
raise NotImplementedError(f'node: {node.op} {node.target}')
# Modified for activation checkpointing
emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -587,6 +601,8 @@ else:
else:
wrap_stmts = ''
ckpt_func = ''.join(ckpt_func)
# If the original function didn't have self as its first argument, we
# would have added it.
if len(orig_args) == 0 or orig_args[0] != 'self':
@ -600,7 +616,7 @@ else:
fn_code = f"""
{wrap_stmts}
{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
import colossalai
{code}"""
return PythonCode(fn_code, globals_)

View File

@ -0,0 +1,158 @@
import os
import warnings
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
from typing import Type, Dict, List, Any, Union, Optional, Set
from pathlib import Path
try:
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
COLOGM = True
except:
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
"""Bind checkpoint functions to ColoGraphModule
We need to bind our checkpoint functions to the GraphModule so
that we could correctly use self.checkpoint for GraphModule forward
"""
ckpt_code = "\n".join(ckpt_def)
globals_copy = globals.copy()
_exec_with_source(ckpt_code, globals_copy)
func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
for func in func_list:
tmp_func = globals_copy[func]
setattr(self, func, tmp_func.__get__(self, self.__class__))
del globals_copy[func]
def recompile(self) -> PythonCode:
"""
Recompile this GraphModule from its ``graph`` attribute. This should be
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')
self._code = python_code.src
# To split ckpt functions code and forward code
_code_list = self._code.split("\n")
_fwd_def = [item for item in _code_list if "def forward" in item][0]
_fwd_idx = _code_list.index(_fwd_def)
ckpt_def = _code_list[:_fwd_idx]
self._code = "\n".join(_code_list[_fwd_idx:])
self.bind(ckpt_def, python_code.globals)
cls = type(self)
cls.forward = _forward_from_src(self._code, python_code.globals)
# Determine whether this class explicitly defines a __call__ implementation
# to wrap. If it does, save it in order to have wrapped_call invoke it.
# If it does not, wrapped_call can use a dynamic call to super() instead.
# In most cases, super().__call__ should be torch.nn.Module.__call__.
# We do not want to hold a reference to Module.__call__ here; doing so will
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
cls.__call__ = call_wrapped
# reset self._code to original src, otherwise to_folder will be wrong
self._code = python_code.src
return python_code
def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
"""Dumps out module to ``folder`` with ``module_name`` so that it can be
imported with ``from <folder> import <module_name>``
Args:
folder (Union[str, os.PathLike]): The folder to write the code out to
module_name (str): Top-level name to use for the ``Module`` while
writing out the code
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt')
tab = " " * 4
# we add import colossalai here
model_str = f"""
import torch
from torch.nn import *
import colossalai
class {module_name}(torch.nn.Module):
def __init__(self):
super().__init__()
"""
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [
nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
else:
return None
blobified_modules = []
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
module_file = folder / f'{module_name}.pt'
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
for buffer_name, buffer in self._buffers.items():
if buffer is None:
continue
model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
for param_name, param in self._parameters.items():
if param is None:
continue
model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py'
module_file.write_text(model_str)
init_file = folder / '__init__.py'
init_file.write_text('from .module import *')
if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")
else:
class ColoGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
super().__init__(root, graph, class_name)

View File

@ -7,6 +7,7 @@ import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy
from colossalai.utils import free_port
@ -72,7 +73,7 @@ def _run_ckpt_solver(rank):
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
@ -102,7 +103,7 @@ def _run_ckpt_solver_torch11(rank):
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
gm = solver(gm)
@ -114,10 +115,12 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
if __name__ == '__main__':
test_ckpt_solver()
test_ckpt_solver_torch11()
_run_ckpt_solver(rank=0)
# test_ckpt_solver()
# test_ckpt_solver_torch11()

View File

@ -9,6 +9,7 @@ from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -46,7 +47,7 @@ class MyModule(torch.nn.Module):
super().__init__()
self.mlp1 = MLP()
self.relu = relu()
self.linear3 = torch.nn.Linear(4, 4)
self.linear2 = torch.nn.Linear(4, 4)
def forward(self, x):
y1, y2 = checkpoint(self.mlp1, x)
@ -56,6 +57,7 @@ class MyModule(torch.nn.Module):
return F.relu(x, inplace=True)
y4 = checkpoint(ckpt2, x)
y4 = self.linear2(y4)
return y1 + y2 + y3 + y4
@ -91,15 +93,15 @@ def _run_act_ckpt_codegen(rank):
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
gm = GraphModule(model, graph)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert checkpoint function will be generated and
# the offload option is correct
code = graph.python_code('self').src
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
# recompile and verify the outputs are consistent
fx_out = gm(data)
@ -145,14 +147,14 @@ def _run_act_ckpt_python_code_torch11(rank):
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
gm = GraphModule(model, graph)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert checkpoint function will be generated and
# the offload option is correct
code = graph.python_code('self').src
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
# recompile and verify the outputs are consistent
fx_out = gm(data)
@ -162,11 +164,10 @@ def _run_act_ckpt_python_code_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_act_ckpt_python_code_torch11():
mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
if __name__ == '__main__':
test_act_ckpt_codegen()
test_act_ckpt_python_code_torch11()
_run_act_ckpt_codegen(rank=0)