From 942c8cd1fb509d32fc741d076b6ed0d80b03902b Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 20 Jul 2022 11:20:38 +0800
Subject: [PATCH] [fx] refactor tracer to trace complete graph (#1342)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [fx] refactor tracer to trace complete graph

* add comments and solve conflicts.
---
Review notes (this section is not part of the commit message):
  * ColoAttribute.__call__ looked up the patched function through an
    undefined name `target`; it now uses `method.__name__`.
  * wrap_tensor_constructor_method: the fallback branch built the ColoProxy
    from an undefined `fx_proxy`; it now wraps the node of the proxy that
    was just created.
  * look_for_proxy: a nested tuple/list that holds no proxy no longer ends
    the search before the remaining args/kwargs are inspected.
  * _define_reflectable: the (rhs, self) tuple is passed as the `args`
    parameter of compute_meta_data_for_functions_proxy instead of being
    unpacked into extra positional arguments.
  * All fixes are same-line replacements, so hunk line counts and the
    diffstat below are unchanged; the `index` blob ids still name the
    original blobs.

 colossalai/fx/proxy.py                        | 55 +++++++++++++---
 colossalai/fx/tracer/_tracer_utils.py         | 19 ++++++
 .../meta_patch/patched_function/torch_ops.py  |  5 ++
 colossalai/fx/tracer/tracer.py                | 62 ++++++++++++++++++-
 tests/test_fx/test_coloproxy.py               | 35 +++++++++--
 .../test_pipeline/test_hf_model/test_opt.py   |  1 -
 .../test_timm_model/test_timm.py              |  1 -
 .../test_tracer/test_hf_model/test_hf_opt.py  |  1 -
 .../test_timm_model/test_timm_model.py        |  1 -
 9 files changed, 160 insertions(+), 20 deletions(-)

diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index e96971b36..464c3598f 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -2,6 +2,7 @@ import operator
 import torch
 from torch.fx.proxy import Proxy, Attribute
 from typing import List, Union, Any
+from colossalai.fx.tracer.meta_patch import meta_patched_function
 
 __all__ = ['ColoProxy']
 
@@ -45,6 +46,14 @@ class ColoProxy(Proxy):
         self._assert_has_meta_data()
         return len(self.meta_data)
 
+    def __int__(self):
+        self._assert_has_meta_data()
+        return int(self.meta_data)
+
+    def __float__(self):
+        self._assert_has_meta_data()
+        return float(self.meta_data)
+
     def __bool__(self):
         self._assert_has_meta_data()
         return self.meta_data
@@ -53,9 +62,6 @@ class ColoProxy(Proxy):
 
         return ColoAttribute(self, k)
 
-    def __setitem__(self, indices, values):
-        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
-
     def __contains__(self, key):
         if self.node.op == "placeholder":
             # this is used to handle like
@@ -65,11 +71,26 @@ class ColoProxy(Proxy):
         return super().__contains__(key)
 
 
+def extract_meta(*args, **kwargs):
+    """
+    This function is copied from _tracer_utils.py to avoid circular import issue.
+    """
+
+    def _convert(val):
+        if isinstance(val, ColoProxy):
+            return val.meta_data
+        elif isinstance(val, (list, tuple)):
+            return type(val)([_convert(ele) for ele in val])
+        return val
+
+    new_args = [_convert(val) for val in args]
+    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
+    return new_args, new_kwargs
+
+
 class ColoAttribute(ColoProxy):
 
     def __init__(self, root, attr: str):
-        # this class is copied from torch.fx.Attribute
-        # but inherits ColoProxy
         self.root = root
         self.attr = attr
         self.tracer = root.tracer
@@ -78,8 +99,28 @@ class ColoAttribute(ColoProxy):
     @property
     def node(self):
         if self._node is None:
-            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
+            proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
+            if not isinstance(proxy, ColoProxy):
+                meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
+                meta_out = getattr(*meta_args, **meta_kwargs)
+                proxy = ColoProxy(proxy.node)
+                proxy.meta_data = meta_out
+            self._node = proxy.node
+
         return self._node
 
     def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
+        proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
+        if not isinstance(proxy, ColoProxy):
+            meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
+            method = getattr(meta_args[0].__class__, self.attr)
+            if meta_patched_function.has(method):
+                meta_target = meta_patched_function.get(method)
+            elif meta_patched_function.has(method.__name__):
+                meta_target = meta_patched_function.get(method.__name__)
+            else:
+                meta_target = method
+            meta_out = meta_target(*meta_args, **meta_kwargs)
+            proxy = ColoProxy(proxy.node)
+            proxy.meta_data = meta_out
+        return proxy
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index 300a82276..0ec49a90a 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -1,5 +1,7 @@
 from typing import List, Union, Any
 from ..proxy import ColoProxy, ColoAttribute
+import torch
+from .meta_patch import meta_patched_function, meta_patched_module
 
 __all__ = ['is_element_in_list', 'extract_meta']
 
@@ -29,3 +31,20 @@ def extract_meta(*args, **kwargs):
     new_args = [_convert(val) for val in args]
     new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
     return new_args, new_kwargs
+
+
+def compute_meta_data_for_functions_proxy(target, args, kwargs):
+    args_metas, kwargs_metas = extract_meta(*args, **kwargs)
+
+    # fetch patched function
+    if meta_patched_function.has(target):
+        meta_target = meta_patched_function.get(target)
+    elif meta_patched_function.has(target.__name__):
+        meta_target = meta_patched_function.get(target.__name__)
+    else:
+        meta_target = target
+    meta_out = meta_target(*args_metas, **kwargs_metas)
+    if isinstance(meta_out, torch.Tensor):
+        meta_out = meta_out.to(device="meta")
+
+    return meta_out
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
index e3342a646..4c5c7c2e3 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
@@ -24,6 +24,11 @@ def torch_arange(*args, **kwargs):
     return torch.empty((end - start) // step, dtype=dtype, device="meta")
 
 
+@meta_patched_function.register(torch.finfo)
+def torch_finfo(*args):
+    return torch.finfo(*args)
+
+
 @meta_patched_function.register(torch.where)
 def torch_where(condition, x, y):
     # torch.where returns the broadcasted tensor of condition, x, and y,
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 5b7a1eced..1415e2f9d 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -7,6 +7,7 @@ tracer.py:
 import enum
 import inspect
 import functools
+import operator
 from colossalai.fx.tracer.meta_patch import meta_patched_module
 import torch
 import torch.nn as nn
@@ -16,8 +17,9 @@ from torch.fx.graph import Graph
 from torch.fx.proxy import Proxy, ParameterProxy
 from ..proxy import ColoProxy
 from typing import Optional, Dict, Any
-from ._tracer_utils import is_element_in_list, extract_meta
+from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
 from .meta_patch import meta_patched_function, meta_patched_module
+from torch.fx.graph import magic_methods, reflectable_magic_methods
 
 __all__ = ['ColoTracer']
 
@@ -61,7 +63,7 @@ class ColoTracer(Tracer):
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
 
-    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
+    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
 
     def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
         """
@@ -344,11 +346,15 @@ def wrap_tensor_constructor_method(target):
         for arg in args:
             if isinstance(arg, Proxy):
                 return arg
+            if isinstance(arg, (tuple, list)) and look_for_proxy(*arg) is not None:
+                return look_for_proxy(*arg)
 
         # find in keyword vars
         for k, v in kwargs.items():
             if isinstance(v, Proxy):
                 return v
+            if isinstance(v, (tuple, list)) and look_for_proxy(*v) is not None:
+                return look_for_proxy(*v)
         return None
 
     @functools.wraps(target)
@@ -358,10 +364,60 @@ def wrap_tensor_constructor_method(target):
         if proxy is not None:
             # if the arg is a proxy, then need to record this function called on this proxy
             # e.g. torch.ones(size) where size is an input proxy
-            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+            colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
+            if not isinstance(colo_proxy, ColoProxy):
+                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
+                colo_proxy = ColoProxy(colo_proxy.node)
+                colo_proxy.meta_data = meta_out
+            return colo_proxy
         else:
             # this is called directly when the inputs do not contain proxy
             # e.g. torch.ones(4) where the input is static
             return target(*args, **kwargs)
 
     return wrapper, target
+
+
+# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__,
+# and add meta_data attribute to the created proxy.
+for method in magic_methods:
+
+    def _scope(method):
+
+        def impl(*args, **kwargs):
+
+            tracer = args[0].tracer
+            target = getattr(operator, method)
+            proxy = tracer.create_proxy('call_function', target, args, kwargs)
+            if not isinstance(proxy, ColoProxy):
+                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
+                proxy = ColoProxy(proxy.node)
+                proxy.meta_data = meta_out
+            return proxy
+
+        impl.__name__ = method
+        as_magic = f'__{method.strip("_")}__'
+        setattr(ColoProxy, as_magic, impl)
+
+    _scope(method)
+
+
+def _define_reflectable(orig_method_name):
+    method_name = f'__r{orig_method_name.strip("_")}__'
+
+    def impl(self, rhs):
+        target = getattr(operator, orig_method_name)
+        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+        if not isinstance(proxy, ColoProxy):
+            meta_out = compute_meta_data_for_functions_proxy(target, (rhs, self), {})
+            proxy = ColoProxy(proxy.node)
+            proxy.meta_data = meta_out
+        return proxy
+
+    impl.__name__ = method_name
+    impl.__qualname__ = method_name
+    setattr(ColoProxy, method_name, impl)
+
+
+for orig_method_name in reflectable_magic_methods:
+    _define_reflectable(orig_method_name)
diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py
index 82be9329d..2bb6cf864 100644
--- a/tests/test_fx/test_coloproxy.py
+++ b/tests/test_fx/test_coloproxy.py
@@ -1,17 +1,40 @@
 import torch
+import torch.nn as nn
 from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from torch.fx import GraphModule
 import pytest
 
 
-@pytest.mark.skip('skip due to tracer')
+class Conv1D(nn.Module):
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.shape[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
+
+
 def test_coloproxy():
-    # create a dummy node only for testing purpose
-    model = torch.nn.Linear(10, 10)
-    gm = torch.fx.symbolic_trace(model)
+
+    tracer = ColoTracer()
+    model = Conv1D(3, 3)
+    input_sample = {'x': torch.rand(3, 3).to('meta')}
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
     node = list(gm.graph.nodes)[0]
 
-    # create proxy
-    proxy = ColoProxy(node=node)
+    proxy = ColoProxy(node=node, tracer=tracer)
     proxy.meta_data = torch.empty(4, 2, device='meta')
 
     assert len(proxy) == 4
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
index bd1b2aa2c..a55ea54fe 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('skip due to tracer')
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
index 81ff4536d..7c3764f34 100644
--- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
+++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
@@ -24,7 +24,6 @@ def test_timm_models_without_control_flow():
         split_model_and_compare_output(model, data)
 
 
-@pytest.mark.skip('skip due to tracer')
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True
 
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index 3206dc75b..5ac051887 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('skip due to tracer')
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
index 38f5a3829..2ee498b9e 100644
--- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
@@ -54,7 +54,6 @@ def test_timm_models_without_control_flow():
         trace_and_compare(model_cls, tracer, data)
 
 
-@pytest.mark.skip('skip due to tracer')
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True
 