[fx] refactor tracer to trace complete graph (#1342)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] refactor tracer to trace complete graph

* add comments and resolve conflicts.
YuliangLiu0306 2022-07-20 11:20:38 +08:00 committed by GitHub
parent 2cc1175c76
commit 942c8cd1fb
9 changed files with 160 additions and 20 deletions
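The refactor lets the tracer build a complete graph from meta inputs. A minimal usage sketch distilled from the updated test in this commit (MyModule is a hypothetical stand-in, input shapes illustrative):

    import torch
    from torch.fx import GraphModule
    from colossalai.fx.tracer.tracer import ColoTracer

    tracer = ColoTracer()
    model = MyModule()                               # hypothetical nn.Module stand-in
    meta_args = {'x': torch.rand(3, 3).to('meta')}   # inputs live on the meta device
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()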


@@ -2,6 +2,7 @@ import operator
 import torch
 from torch.fx.proxy import Proxy, Attribute
 from typing import List, Union, Any
+from colossalai.fx.tracer.meta_patch import meta_patched_function
 
 __all__ = ['ColoProxy']
@@ -45,6 +46,14 @@ class ColoProxy(Proxy):
         self._assert_has_meta_data()
         return len(self.meta_data)
 
+    def __int__(self):
+        self._assert_has_meta_data()
+        return int(self.meta_data)
+
+    def __float__(self):
+        self._assert_has_meta_data()
+        return float(self.meta_data)
+
     def __bool__(self):
         self._assert_has_meta_data()
         return self.meta_data
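These conversion methods let traced code treat a proxy as a plain Python number whenever meta data is available. A hedged illustration (proxy construction omitted, values hypothetical):

    # assuming `p` is a ColoProxy whose meta_data holds a scalar value
    n = int(p)      # returns int(p.meta_data) instead of raising TypeError
    f = float(p)    # likewise for float conversion inside traced code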
@@ -53,9 +62,6 @@ class ColoProxy(Proxy):
         return ColoAttribute(self, k)
 
-    def __setitem__(self, indices, values):
-        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
-
     def __contains__(self, key):
         if self.node.op == "placeholder":
             # this is used to handle like
@@ -65,11 +71,26 @@ class ColoProxy(Proxy):
         return super().__contains__(key)
 
+def extract_meta(*args, **kwargs):
+    """
+    This function is copied from _tracer_utils.py to avoid a circular import issue.
+    """
+    def _convert(val):
+        if isinstance(val, ColoProxy):
+            return val.meta_data
+        elif isinstance(val, (list, tuple)):
+            return type(val)([_convert(ele) for ele in val])
+        return val
+    new_args = [_convert(val) for val in args]
+    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
+    return new_args, new_kwargs
 
 class ColoAttribute(ColoProxy):
 
     def __init__(self, root, attr: str):
+        # this class is copied from torch.fx.Attribute
+        # but inherits from ColoProxy
         self.root = root
         self.attr = attr
         self.tracer = root.tracer
@@ -78,8 +99,28 @@ class ColoAttribute(ColoProxy):
     @property
     def node(self):
         if self._node is None:
-            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
+            proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
+            if not isinstance(proxy, ColoProxy):
+                meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
+                meta_out = getattr(*meta_args, **meta_kwargs)
+                proxy = ColoProxy(proxy.node)
+                proxy.meta_data = meta_out
+            self._node = proxy.node
         return self._node
 
     def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
+        proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
+        if not isinstance(proxy, ColoProxy):
+            meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
+            method = getattr(meta_args[0].__class__, self.attr)
+            if meta_patched_function.has(method):
+                meta_target = meta_patched_function.get(method)
+            elif meta_patched_function.has(method.__name__):
+                meta_target = meta_patched_function.get(method.__name__)
+            else:
+                meta_target = method
+            meta_out = meta_target(*meta_args, **meta_kwargs)
+            proxy = ColoProxy(proxy.node)
+            proxy.meta_data = meta_out
+        return proxy


@@ -1,5 +1,7 @@
 from typing import List, Union, Any
 from ..proxy import ColoProxy, ColoAttribute
+import torch
+from .meta_patch import meta_patched_function, meta_patched_module
 
 __all__ = ['is_element_in_list', 'extract_meta']
@@ -29,3 +31,20 @@ def extract_meta(*args, **kwargs):
     new_args = [_convert(val) for val in args]
     new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
     return new_args, new_kwargs
+
+
+def compute_meta_data_for_functions_proxy(target, args, kwargs):
+    args_metas, kwargs_metas = extract_meta(*args, **kwargs)
+    # fetch the patched function for the target, falling back to the target itself
+    if meta_patched_function.has(target):
+        meta_target = meta_patched_function.get(target)
+    elif meta_patched_function.has(target.__name__):
+        meta_target = meta_patched_function.get(target.__name__)
+    else:
+        meta_target = target
+    meta_out = meta_target(*args_metas, **kwargs_metas)
+    if isinstance(meta_out, torch.Tensor):
+        meta_out = meta_out.to(device="meta")
+    return meta_out
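For intuition, a hedged sketch of what the new helper returns (proxy names and shapes are hypothetical, not from this diff):

    # proxy_a, proxy_b: ColoProxy objects whose meta_data are meta tensors of
    # shapes (2, 3) and (3, 4); torch.matmul is resolved via meta_patched_function
    # or, failing that, called directly on the meta tensors
    meta_out = compute_meta_data_for_functions_proxy(torch.matmul, (proxy_a, proxy_b), {})
    # meta_out is a tensor on device="meta" with shape (2, 4),
    # later attached to a ColoProxy as its meta_data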


@@ -24,6 +24,11 @@ def torch_arange(*args, **kwargs):
     return torch.empty((end - start) // step, dtype=dtype, device="meta")
 
 
+@meta_patched_function.register(torch.finfo)
+def torch_finfo(*args):
+    return torch.finfo(*args)
+
+
 @meta_patched_function.register(torch.where)
 def torch_where(condition, x, y):
     # torch.where returns the broadcasted tensor of condition, x, and y,
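torch.finfo consumes a dtype rather than tensor data, so the patched version can simply forward to the real call. A hedged example of the pattern this presumably unblocks during tracing (model code assumed, not part of this diff):

    # e.g. attention masking commonly does:
    mask_value = torch.finfo(hidden_states.dtype).min   # works even when hidden_states is a meta tensor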


@@ -7,6 +7,7 @@ tracer.py:
 import enum
 import inspect
 import functools
+import operator
 from colossalai.fx.tracer.meta_patch import meta_patched_module
 import torch
 import torch.nn as nn
@@ -16,8 +17,9 @@ from torch.fx.graph import Graph
 from torch.fx.proxy import Proxy, ParameterProxy
 from ..proxy import ColoProxy
 from typing import Optional, Dict, Any
-from ._tracer_utils import is_element_in_list, extract_meta
+from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
 from .meta_patch import meta_patched_function, meta_patched_module
+from torch.fx.graph import magic_methods, reflectable_magic_methods
 
 __all__ = ['ColoTracer']
@@ -61,7 +63,7 @@ class ColoTracer(Tracer):
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
 
-    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
+    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
 
     def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
         """
@@ -344,11 +346,15 @@ def wrap_tensor_constructor_method(target):
         for arg in args:
             if isinstance(arg, Proxy):
                 return arg
+            if isinstance(arg, (tuple, list)):
+                return look_for_proxy(*arg)
 
         # find in keyword vars
         for k, v in kwargs.items():
             if isinstance(v, Proxy):
                 return v
+            if isinstance(v, (tuple, list)):
+                return look_for_proxy(*v)
         return None
 
     @functools.wraps(target)
@@ -358,10 +364,60 @@ def wrap_tensor_constructor_method(target):
         if proxy is not None:
             # if the arg is a proxy, then need to record this function called on this proxy
             # e.g. torch.ones(size) where size is an input proxy
-            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+            colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
+            if not isinstance(colo_proxy, ColoProxy):
+                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
+                colo_proxy = ColoProxy(colo_proxy.node)
+                colo_proxy.meta_data = meta_out
+            return colo_proxy
         else:
             # this is called directly when the inputs do not contain proxy
             # e.g. torch.ones(4) where the input is static
             return target(*args, **kwargs)
 
     return wrapper, target
+
+
+# Patch magic methods onto ColoProxy so that the tracer records magic methods such as __sub__
+# and attaches a meta_data attribute to the created proxy.
+for method in magic_methods:
+
+    def _scope(method):
+
+        def impl(*args, **kwargs):
+            tracer = args[0].tracer
+            target = getattr(operator, method)
+            proxy = tracer.create_proxy('call_function', target, args, kwargs)
+            if not isinstance(proxy, ColoProxy):
+                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
+                proxy = ColoProxy(proxy.node)
+                proxy.meta_data = meta_out
+            return proxy
+
+        impl.__name__ = method
+        as_magic = f'__{method.strip("_")}__'
+        setattr(ColoProxy, as_magic, impl)
+
+    _scope(method)
+
+
+def _define_reflectable(orig_method_name):
+    method_name = f'__r{orig_method_name.strip("_")}__'
+
+    def impl(self, rhs):
+        target = getattr(operator, orig_method_name)
+        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+        if not isinstance(proxy, ColoProxy):
+            meta_out = compute_meta_data_for_functions_proxy(target, (rhs, self), {})
+            proxy = ColoProxy(proxy.node)
+            proxy.meta_data = meta_out
+        return proxy
+
+    impl.__name__ = method_name
+    impl.__qualname__ = method_name
+    setattr(ColoProxy, method_name, impl)
+
+
+for orig_method_name in reflectable_magic_methods:
+    _define_reflectable(orig_method_name)
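With the magic methods patched, arithmetic on proxies is recorded as call_function nodes while keeping meta information. A hedged illustration (proxy names hypothetical):

    out = proxy_a + proxy_b        # recorded as call_function(operator.add, ...)
    print(out.meta_data.shape)     # broadcast meta shape of the two operands
    rout = 1 - proxy_a             # reflected __rsub__ is recorded the same way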


@@ -1,17 +1,40 @@
 import torch
+import torch.nn as nn
 from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from torch.fx import GraphModule
 import pytest
 
 
-@pytest.mark.skip('skip due to tracer')
+class Conv1D(nn.Module):
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.shape[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
+
+
 def test_coloproxy():
-    # create a dummy node only for testing purpose
-    model = torch.nn.Linear(10, 10)
-    gm = torch.fx.symbolic_trace(model)
+    tracer = ColoTracer()
+    model = Conv1D(3, 3)
+    input_sample = {'x': torch.rand(3, 3).to('meta')}
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
     node = list(gm.graph.nodes)[0]
-    # create proxy
-    proxy = ColoProxy(node=node)
+    proxy = ColoProxy(node=node, tracer=tracer)
     proxy.meta_data = torch.empty(4, 2, device='meta')
 
     assert len(proxy) == 4


@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
-@pytest.mark.skip('skip due to tracer')
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,


@@ -24,7 +24,6 @@ def test_timm_models_without_control_flow():
     split_model_and_compare_output(model, data)
 
-@pytest.mark.skip('skip due to tracer')
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True


@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
-@pytest.mark.skip('skip due to tracer')
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,


@@ -54,7 +54,6 @@ def test_timm_models_without_control_flow():
     trace_and_compare(model_cls, tracer, data)
 
-@pytest.mark.skip('skip due to tracer')
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True