mirror of https://github.com/hpcaitech/ColossalAI
[fx] allow native ckpt trace and codegen. (#2438)
parent 41429b9b28
commit c41e59e5ad
@@ -1,17 +1,21 @@
 import os
 import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set, Type, Union
+
 import torch
 import torch.nn as nn
 from torch.nn.modules.module import _addindent
-from typing import Type, Dict, List, Any, Union, Optional, Set
-from pathlib import Path
 
 try:
-    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
-    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+    from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
+    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+
+    from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
     COLOGM = True
 except:
-    from torch.fx.graph_module import GraphModule
     from torch.fx.graph import Graph
+    from torch.fx.graph_module import GraphModule
     COLOGM = False
 
+
@@ -19,6 +23,7 @@ if COLOGM:
     class ColoGraphModule(GraphModule):
 
         def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            graph.set_codegen(ActivationCheckpointCodeGen())
             super().__init__(root, graph, class_name)
 
         def bind(self, ckpt_def, globals):
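
The one functional change in this file is the set_codegen call: ColoGraphModule now installs ActivationCheckpointCodeGen before GraphModule generates forward(). A minimal sketch of that mechanism using only stock torch.fx (the plain CodeGen below is a stand-in for ActivationCheckpointCodeGen; set_codegen requires PyTorch >= 1.12, which is what the try/except guard above checks for):

    import torch
    import torch.fx

    class Net(torch.nn.Module):
        def forward(self, x):
            return x.relu() + 1

    graph = torch.fx.symbolic_trace(Net()).graph
    # installing a codegen changes how GraphModule turns the graph into
    # the Python source of forward()
    graph.set_codegen(torch.fx.graph.CodeGen())    # stand-in for ActivationCheckpointCodeGen()
    gm = torch.fx.GraphModule(Net(), graph)
    print(gm.code)
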
@@ -13,6 +13,7 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     """
     Symbolic tracing API
@@ -49,6 +50,6 @@ def symbolic_trace(
     This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
 
     """
-    graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+    graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
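
For context, a hypothetical call exercising the new flag. The colossalai.fx import path, the Toy module, and the meta_args shape are illustrative assumptions, not taken from this diff; only the trace_act_ckpt keyword itself comes from the change above:

    import torch
    import torch.utils.checkpoint as ckpt
    from colossalai.fx import symbolic_trace    # import path assumed

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 8)

        def forward(self, x):
            # the checkpointed region the tracer can now record natively
            return ckpt.checkpoint(self.fc, x)

    gm = symbolic_trace(Toy(),
                        meta_args={'x': torch.rand(2, 8, device='meta')},
                        trace_act_ckpt=True)
    print(gm.code)
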
@@ -1,7 +1,7 @@
 import enum
 import functools
-import operator
 import inspect
+import operator
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
@@ -286,7 +286,6 @@ class ColoTracer(Tracer):
         self.graph.lint()
         return self.graph
 
-
     @contextmanager
     def trace_activation_checkpoint(self, enabled: bool):
         if enabled:
@@ -316,7 +315,6 @@ class ColoTracer(Tracer):
         # recover the checkpoint function upon exit
         torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
 
-
     def _post_check(self, non_concrete_arg_names: Set[str]):
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
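
The two restored lines above are the tail of the patch-and-restore pattern that trace_activation_checkpoint uses. Stated as a self-contained sketch (the context manager name and its argument are illustrative; only the save/patch/restore of torch.utils.checkpoint.CheckpointFunction comes from the diff):

    from contextlib import contextmanager

    import torch.utils.checkpoint

    @contextmanager
    def swap_checkpoint_function(replacement):
        # save the original autograd Function, install the interceptor,
        # and always restore the original on exit
        orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
        torch.utils.checkpoint.CheckpointFunction = replacement
        try:
            yield
        finally:
            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
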
@@ -385,18 +383,23 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     if is_compatible_with_meta():
         if meta_args is not None:
             root.to(default_device())
             wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
-            graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
+            graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+                                                                    concrete_args=concrete_args,
+                                                                    meta_args=tree_map(wrap_fn, meta_args))
             root.cpu()
         else:
             graph = Tracer().trace(root, concrete_args=concrete_args)
     else:
         from .tracer import ColoTracer as OrigColoTracer
-        graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+        graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+                                                                    concrete_args=concrete_args,
+                                                                    meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
 
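
The wrap_fn/tree_map pair simply lifts every tensor in a possibly nested meta_args structure onto the meta device while passing other leaves through. In isolation, with a plain .to('meta') standing in for Colossal-AI's MetaTensor wrapper (tree_map is the same pytree helper the tracer imports):

    import torch
    from torch.utils._pytree import tree_map

    wrap_fn = lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x

    meta_args = {'x': torch.rand(2, 3), 'seq_len': 128}
    wrapped = tree_map(wrap_fn, meta_args)
    print(wrapped['x'].is_meta, wrapped['seq_len'])    # True 128
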
@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
-            node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs)
+            node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
+                                                   node.kwargs)
 
 
 def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
     if kind == 'placeholder':
-        meta_out = meta_args[target] if target in meta_args else concrete_args.get(
-            _truncate_suffix(target), None)
+        meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
     elif kind == 'get_attr':
         attr_itr = root
         atoms = target.split(".")
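
The get_attr branch walks a dotted target string one attribute at a time (attr_itr plus target.split(".")). The equivalent standalone helper, with hypothetical names:

    import torch

    def resolve_attr(root: torch.nn.Module, target: str):
        # "0.weight" -> getattr(getattr(root, "0"), "weight")
        attr_itr = root
        for atom in target.split("."):
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    model = torch.nn.Sequential(torch.nn.Linear(2, 2))
    print(resolve_attr(model, "0.weight").shape)    # torch.Size([2, 2])
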
@@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
         else:
             if target not in _TensorPropertyMethod:
                 meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                                   **tree_map(unwrap_fn, kwargs))
+                                                               **tree_map(unwrap_fn, kwargs))
     elif kind == 'call_module':
         mod = root.get_submodule(target)
         meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
         meta_out = None
     return meta_out
 
+
 def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
         meta_out = meta_args[target]
@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     return meta_out
 
 
-def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
+def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
     result_graph = Graph()
     value_remap = {}
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
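
bias_addition_pass rewrites biased ops so the bias term becomes its own graph node. Functionally, for the F.linear case handled in the next hunk, the substitution amounts to the identity below; torch.addmm is one possible substitute form, shown only to make the decomposition concrete, not the pass's actual target function:

    import torch
    import torch.nn.functional as F

    x, w, b = torch.rand(4, 3), torch.rand(5, 3), torch.rand(5)
    fused = F.linear(x, w, b)            # bias folded into one op
    split = torch.addmm(b, x, w.t())     # explicit matmul-plus-bias form
    assert torch.allclose(fused, split)
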
@@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
                 if target == torch.nn.functional.linear:
                     if 'bias' in kwargs and kwargs['bias'] is not None:
                         function_to_substitute = func_to_func_dict[target]
-                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                    function_to_substitute)
                 else:
                     function_to_substitute = func_to_func_dict[target]
-                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                function_to_substitute)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
                 function_to_substitute = func_to_func_dict[target]
-                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
+                                                                     function_to_substitute)
 
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
             if bias_addition_method.has(method):
                 function_to_substitute = method_to_func_dict[method]
-                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
+                                                          function_to_substitute)
 
         elif kind == "call_module":
             # if not hasattr(self, "orig_forward"):
@@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
             mod_type = type(mod)
             if bias_addition_module.has(mod_type) and mod.bias is not None:
                 function_to_substitute = module_to_func_dict[mod_type]
-                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
+                                                            function_to_substitute)
 
         if handle is not None:
             handle.generate()
             for node_inserted in tracer.graph.nodes:
-                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
+                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
                 last_node = value_remap[node_inserted]
             value_remap[orig_node] = last_node
         else:
-            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
+            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
 
     del tracer
 
     gm.graph = result_graph
     gm.recompile()
     meta_prop_pass(gm, root_model, meta_args)
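
The copy loop above is torch.fx's standard graph-rewrite idiom: node_copy clones each node into a fresh Graph while value_remap rewires the clone's inputs to the already-copied nodes, and gm.recompile() regenerates forward(). A minimal self-contained version of the same idiom (the Linear module is illustrative):

    import torch
    import torch.fx

    gm = torch.fx.symbolic_trace(torch.nn.Linear(2, 2))
    result_graph = torch.fx.Graph()
    value_remap = {}
    for node in gm.graph.nodes:
        # the lambda maps each input of the old node to its new copy
        value_remap[node] = result_graph.node_copy(node, lambda n: value_remap[n])

    gm.graph = result_graph
    gm.recompile()    # regenerate forward() from the rewritten graph
    print(gm.code)
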