mirror of https://github.com/hpcaitech/ColossalAI
[fx] refactor tracer to trace complete graph (#1342)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [fx] refactor tracer to trace complete graph
* add comments and solve conflicts.
pull/1344/head
parent
2cc1175c76
commit
942c8cd1fb
|
@ -2,6 +2,7 @@ import operator
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.proxy import Proxy, Attribute
|
from torch.fx.proxy import Proxy, Attribute
|
||||||
from typing import List, Union, Any
|
from typing import List, Union, Any
|
||||||
|
from colossalai.fx.tracer.meta_patch import meta_patched_function
|
||||||
|
|
||||||
__all__ = ['ColoProxy']
|
__all__ = ['ColoProxy']
|
||||||
|
|
||||||
|
@ -45,6 +46,14 @@ class ColoProxy(Proxy):
|
||||||
self._assert_has_meta_data()
|
self._assert_has_meta_data()
|
||||||
return len(self.meta_data)
|
return len(self.meta_data)
|
||||||
|
|
||||||
|
def __int__(self):
    """Return ``int(self.meta_data)``.

    ``self._assert_has_meta_data()`` (defined outside this view) is called first.
    """
    self._assert_has_meta_data()
    meta_value = self.meta_data
    return int(meta_value)
|
||||||
|
|
||||||
|
def __float__(self):
    """Return ``float(self.meta_data)``.

    ``self._assert_has_meta_data()`` (defined outside this view) is called first.
    """
    self._assert_has_meta_data()
    meta_value = self.meta_data
    return float(meta_value)
|
||||||
|
|
||||||
def __bool__(self):
    """Return the truth value of ``self.meta_data``.

    ``self._assert_has_meta_data()`` (defined outside this view) is called first.
    """
    self._assert_has_meta_data()
    # Python requires __bool__ to return an actual bool; returning the raw meta data raises
    # TypeError for anything that is not already True/False (e.g. an int).
    return bool(self.meta_data)
|
||||||
|
@ -53,9 +62,6 @@ class ColoProxy(Proxy):
|
||||||
|
|
||||||
return ColoAttribute(self, k)
|
return ColoAttribute(self, k)
|
||||||
|
|
||||||
def __setitem__(self, indices, values):
    """Ask the tracer for a ``call_function`` proxy targeting ``operator.setitem``.

    The recorded arguments are ``(self, indices, values)``; the proxy created by the tracer is
    returned.
    """
    setitem_args = (self, indices, values)
    return self.tracer.create_proxy("call_function", operator.setitem, setitem_args, {})
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
if self.node.op == "placeholder":
|
if self.node.op == "placeholder":
|
||||||
# this is used to handle like
|
# this is used to handle like
|
||||||
|
@ -65,11 +71,26 @@ class ColoProxy(Proxy):
|
||||||
return super().__contains__(key)
|
return super().__contains__(key)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_meta(*args, **kwargs):
    """
    Replace every ColoProxy found in the arguments with its ``meta_data``.

    Lists and tuples are rebuilt element by element (recursively) with their original type;
    any other value is passed through unchanged.

    This function is copied from _tracer_utils.py to avoid circular import issue.

    Returns:
        A list with the converted positional arguments and a dict with the converted
        keyword arguments.
    """

    def _unwrap(item):
        if isinstance(item, ColoProxy):
            return item.meta_data
        if isinstance(item, (list, tuple)):
            return type(item)([_unwrap(child) for child in item])
        return item

    meta_args = list(map(_unwrap, args))
    meta_kwargs = {}
    for name, value in kwargs.items():
        meta_kwargs[name] = _unwrap(value)
    return meta_args, meta_kwargs
|
||||||
|
|
||||||
|
|
||||||
class ColoAttribute(ColoProxy):
|
class ColoAttribute(ColoProxy):
|
||||||
|
|
||||||
def __init__(self, root, attr: str):
|
def __init__(self, root, attr: str):
|
||||||
# this class is copied from torch.fx.Attribute
|
|
||||||
# but inherits ColoProxy
|
|
||||||
self.root = root
|
self.root = root
|
||||||
self.attr = attr
|
self.attr = attr
|
||||||
self.tracer = root.tracer
|
self.tracer = root.tracer
|
||||||
|
@ -78,8 +99,28 @@ class ColoAttribute(ColoProxy):
|
||||||
@property
def node(self):
    """Return the graph node for ``getattr(root, attr)``.

    The node is created on first access and cached in ``self._node``. When the tracer hands
    back a proxy that is not a ColoProxy, the attribute is also looked up on the extracted
    meta data and the node is wrapped in a ColoProxy carrying that result.
    """
    if self._node is not None:
        return self._node

    getattr_proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
    if not isinstance(getattr_proxy, ColoProxy):
        meta_args, meta_kwargs = extract_meta(self.root, self.attr)
        meta_out = getattr(*meta_args, **meta_kwargs)
        getattr_proxy = ColoProxy(getattr_proxy.node)
        getattr_proxy.meta_data = meta_out
    self._node = getattr_proxy.node
    return self._node
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
    """Record a ``call_method`` node that invokes ``self.attr`` on ``self.root``.

    If the tracer returns a proxy that is not a ColoProxy, the meta data of the result is
    computed by running the method on the extracted meta arguments and the node is re-wrapped
    in a ColoProxy carrying that meta data. A registered meta patch is used instead of the
    method when one exists, looked up first by the method object and then by its name.

    Returns:
        The proxy for the recorded method call.
    """
    call_args = (self.root,) + args
    proxy = self.tracer.create_proxy("call_method", self.attr, call_args, kwargs)
    if isinstance(proxy, ColoProxy):
        return proxy

    meta_args, meta_kwargs = extract_meta(*call_args, **kwargs)
    method = getattr(meta_args[0].__class__, self.attr)
    if meta_patched_function.has(method):
        meta_target = meta_patched_function.get(method)
    # the name-based lookup must use `method`; there is no `target` in this scope
    elif meta_patched_function.has(method.__name__):
        meta_target = meta_patched_function.get(method.__name__)
    else:
        meta_target = method
    meta_out = meta_target(*meta_args, **meta_kwargs)
    proxy = ColoProxy(proxy.node)
    proxy.meta_data = meta_out
    return proxy
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from typing import List, Union, Any
|
from typing import List, Union, Any
|
||||||
from ..proxy import ColoProxy, ColoAttribute
|
from ..proxy import ColoProxy, ColoAttribute
|
||||||
|
import torch
|
||||||
|
from .meta_patch import meta_patched_function, meta_patched_module
|
||||||
|
|
||||||
__all__ = ['is_element_in_list', 'extract_meta']
|
__all__ = ['is_element_in_list', 'extract_meta']
|
||||||
|
|
||||||
|
@ -29,3 +31,20 @@ def extract_meta(*args, **kwargs):
|
||||||
new_args = [_convert(val) for val in args]
|
new_args = [_convert(val) for val in args]
|
||||||
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
||||||
return new_args, new_kwargs
|
return new_args, new_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def compute_meta_data_for_functions_proxy(target, args, kwargs):
    """Run ``target`` (or its registered meta patch) on the meta data of the arguments.

    Args:
        target: the callable being traced.
        args: positional arguments of the call; converted with ``extract_meta``.
        kwargs: keyword arguments of the call; converted with ``extract_meta``.

    Returns:
        The result of the call. A ``torch.Tensor`` result is moved to the ``meta`` device.
    """
    args_metas, kwargs_metas = extract_meta(*args, **kwargs)

    # fetch patched function: look up by the callable itself first, then by its name
    if meta_patched_function.has(target):
        patch_key = target
    elif meta_patched_function.has(target.__name__):
        patch_key = target.__name__
    else:
        patch_key = None
    meta_target = target if patch_key is None else meta_patched_function.get(patch_key)

    meta_out = meta_target(*args_metas, **kwargs_metas)
    if not isinstance(meta_out, torch.Tensor):
        return meta_out
    return meta_out.to(device="meta")
|
||||||
|
|
|
@ -24,6 +24,11 @@ def torch_arange(*args, **kwargs):
|
||||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||||
|
|
||||||
|
|
||||||
|
@meta_patched_function.register(torch.finfo)
def torch_finfo(*args):
    """Meta patch registered for ``torch.finfo``: forwards the positional arguments unchanged."""
    finfo = torch.finfo(*args)
    return finfo
|
||||||
|
|
||||||
|
|
||||||
@meta_patched_function.register(torch.where)
|
@meta_patched_function.register(torch.where)
|
||||||
def torch_where(condition, x, y):
|
def torch_where(condition, x, y):
|
||||||
# torch.where returns the broadcasted tensor of condition, x, and y,
|
# torch.where returns the broadcasted tensor of condition, x, and y,
|
||||||
|
|
|
@ -7,6 +7,7 @@ tracer.py:
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
import functools
|
import functools
|
||||||
|
import operator
|
||||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -16,8 +17,9 @@ from torch.fx.graph import Graph
|
||||||
from torch.fx.proxy import Proxy, ParameterProxy
|
from torch.fx.proxy import Proxy, ParameterProxy
|
||||||
from ..proxy import ColoProxy
|
from ..proxy import ColoProxy
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from ._tracer_utils import is_element_in_list, extract_meta
|
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
|
||||||
from .meta_patch import meta_patched_function, meta_patched_module
|
from .meta_patch import meta_patched_function, meta_patched_module
|
||||||
|
from torch.fx.graph import magic_methods, reflectable_magic_methods
|
||||||
|
|
||||||
__all__ = ['ColoTracer']
|
__all__ = ['ColoTracer']
|
||||||
|
|
||||||
|
@ -61,7 +63,7 @@ class ColoTracer(Tracer):
|
||||||
# Feature flag for proxying accesses to buffer values
|
# Feature flag for proxying accesses to buffer values
|
||||||
proxy_buffer_attributes: bool = True
|
proxy_buffer_attributes: bool = True
|
||||||
|
|
||||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
|
||||||
|
|
||||||
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
|
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
|
||||||
"""
|
"""
|
||||||
|
@ -344,11 +346,15 @@ def wrap_tensor_constructor_method(target):
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if isinstance(arg, Proxy):
|
if isinstance(arg, Proxy):
|
||||||
return arg
|
return arg
|
||||||
|
if isinstance(arg, (tuple, list)):
|
||||||
|
return look_for_proxy(*arg)
|
||||||
|
|
||||||
# find in keyword vars
|
# find in keyword vars
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if isinstance(v, Proxy):
|
if isinstance(v, Proxy):
|
||||||
return v
|
return v
|
||||||
|
if isinstance(v, (tuple, list)):
|
||||||
|
return look_for_proxy(*v)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@functools.wraps(target)
|
@functools.wraps(target)
|
||||||
|
@ -358,10 +364,60 @@ def wrap_tensor_constructor_method(target):
|
||||||
if proxy is not None:
|
if proxy is not None:
|
||||||
# if the arg is a proxy, then need to record this function called on this proxy
|
# if the arg is a proxy, then need to record this function called on this proxy
|
||||||
# e.g. torch.ones(size) where size is an input proxy
|
# e.g. torch.ones(size) where size is an input proxy
|
||||||
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
||||||
|
if not isinstance(colo_proxy, ColoProxy):
|
||||||
|
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
|
||||||
|
colo_proxy = ColoProxy(fx_proxy.node)
|
||||||
|
colo_proxy.meta_data = meta_out
|
||||||
|
return colo_proxy
|
||||||
else:
|
else:
|
||||||
# this is called directly when the inputs do not contain proxy
|
# this is called directly when the inputs do not contain proxy
|
||||||
# e.g. torch.ones(4) where the input is static
|
# e.g. torch.ones(4) where the input is static
|
||||||
return target(*args, **kwargs)
|
return target(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper, target
|
return wrapper, target
|
||||||
|
|
||||||
|
|
||||||
|
# Patch the magic methods of ColoProxy so that the tracer can record magic methods like __sub__,
|
||||||
|
# and add meta_data attribute to the created proxy.
|
||||||
|
for method in magic_methods:

    def _scope(method):
        # `method` is taken as a parameter so that every `impl` closes over its own operator
        # name rather than over the loop variable.

        def impl(*args, **kwargs):
            # args[0] is the proxy the magic method is invoked on
            tracer = args[0].tracer
            target = getattr(operator, method)
            result = tracer.create_proxy('call_function', target, args, kwargs)
            if isinstance(result, ColoProxy):
                return result
            # the tracer returned some other proxy: compute meta data and re-wrap the node
            meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
            colo_result = ColoProxy(result.node)
            colo_result.meta_data = meta_out
            return colo_result

        impl.__name__ = method
        setattr(ColoProxy, f'__{method.strip("_")}__', impl)

    _scope(method)
|
||||||
|
|
||||||
|
|
||||||
|
def _define_reflectable(orig_method_name):
    """Install the reflected variant (``__r<name>__``) of an operator on ColoProxy.

    The installed method records ``operator.<orig_method_name>(rhs, self)`` as a
    ``call_function`` node. If the tracer returns a proxy that is not a ColoProxy, meta data
    is computed for the call and the node is re-wrapped in a ColoProxy carrying it.

    Args:
        orig_method_name: name of the function in the ``operator`` module to reflect.
    """
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        # operands are swapped because `self` is the right-hand operand of a reflected call
        call_args = (rhs, self)
        proxy = self.tracer.create_proxy('call_function', target, call_args, {})
        if not isinstance(proxy, ColoProxy):
            # args and kwargs are passed as a tuple and a dict; unpacking the tuple here would
            # pass four positional arguments to a three-parameter function
            meta_out = compute_meta_data_for_functions_proxy(target, call_args, {})
            proxy = ColoProxy(proxy.node)
            proxy.meta_data = meta_out
        return proxy

    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(ColoProxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
|
||||||
|
|
|
@ -1,17 +1,40 @@
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from colossalai.fx.proxy import ColoProxy
|
from colossalai.fx.proxy import ColoProxy
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from torch.fx import GraphModule
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('skip due to tracer')
|
class Conv1D(nn.Module):
    """Small affine layer used as a tracing fixture.

    Holds a weight of shape ``(nx, nf)`` initialised from a normal distribution with
    ``std=0.02`` and a zero bias of shape ``(nf,)``.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        init_weight = torch.empty(nx, nf)
        nn.init.normal_(init_weight, std=0.02)
        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Apply ``bias + x @ weight`` over the last dimension of ``x``.

        ``x`` is flattened to two dimensions for ``torch.addmm`` and the result is viewed back
        to ``x.shape[:-1] + (nf,)``.
        """
        # the sequence of tensor operations is kept exactly as is because it is what gets traced
        out_shape = x.shape[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return x.view(out_shape)
|
||||||
|
|
||||||
|
|
||||||
def test_coloproxy():
|
def test_coloproxy():
|
||||||
# create a dummy node only for testing purpose
|
|
||||||
model = torch.nn.Linear(10, 10)
|
tracer = ColoTracer()
|
||||||
gm = torch.fx.symbolic_trace(model)
|
model = Conv1D(3, 3)
|
||||||
|
input_sample = {'x': torch.rand(3, 3).to('meta')}
|
||||||
|
|
||||||
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
node = list(gm.graph.nodes)[0]
|
node = list(gm.graph.nodes)[0]
|
||||||
|
|
||||||
# create proxy
|
proxy = ColoProxy(node=node, tracer=tracer)
|
||||||
proxy = ColoProxy(node=node)
|
|
||||||
proxy.meta_data = torch.empty(4, 2, device='meta')
|
proxy.meta_data = torch.empty(4, 2, device='meta')
|
||||||
|
|
||||||
assert len(proxy) == 4
|
assert len(proxy) == 4
|
||||||
|
|
|
@ -7,7 +7,6 @@ BATCH_SIZE = 1
|
||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('skip due to tracer')
|
|
||||||
def test_opt():
|
def test_opt():
|
||||||
MODEL_LIST = [
|
MODEL_LIST = [
|
||||||
transformers.OPTModel,
|
transformers.OPTModel,
|
||||||
|
|
|
@ -24,7 +24,6 @@ def test_timm_models_without_control_flow():
|
||||||
split_model_and_compare_output(model, data)
|
split_model_and_compare_output(model, data)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('skip due to tracer')
|
|
||||||
def test_timm_models_with_control_flow():
|
def test_timm_models_with_control_flow():
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ BATCH_SIZE = 1
|
||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('skip due to tracer')
|
|
||||||
def test_opt():
|
def test_opt():
|
||||||
MODEL_LIST = [
|
MODEL_LIST = [
|
||||||
transformers.OPTModel,
|
transformers.OPTModel,
|
||||||
|
|
|
@ -54,7 +54,6 @@ def test_timm_models_without_control_flow():
|
||||||
trace_and_compare(model_cls, tracer, data)
|
trace_and_compare(model_cls, tracer, data)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('skip due to tracer')
|
|
||||||
def test_timm_models_with_control_flow():
|
def test_timm_models_with_control_flow():
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue