From 3a02b464479b5da51bab4fa5b65cc7066dece229 Mon Sep 17 00:00:00 2001
From: Zihao <804673818@qq.com>
Date: Wed, 4 Jan 2023 14:44:22 +0800
Subject: [PATCH] [auto-parallel] refactoring ColoTracer (#2118)

* add meta_data_computing

* add checkpoint_annotation

* rename proxy.data to proxy.meta_data and add bias addition pass

* polish code

* delete meta_prop_pass invocation and rename ori_node to orig_node

* add TracerType

* unify meta data computing

* delete TracerType

* handle setitem operation

* support operator.setitem
---
 colossalai/fx/tracer/experimental.py | 340 +++++++++++++++++++++++----
 1 file changed, 294 insertions(+), 46 deletions(-)

diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py
index 66e714912..6fee5f5d0 100644
--- a/colossalai/fx/tracer/experimental.py
+++ b/colossalai/fx/tracer/experimental.py
@@ -1,6 +1,8 @@
 import enum
 import functools
 import inspect
+import operator
+from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import torch
@@ -8,6 +10,15 @@ from torch.fx import Graph, Node, Proxy, Tracer
 from torch.utils._pytree import tree_map
 
 from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
+from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
+from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
+from colossalai.fx.tracer.registry import (
+    bias_addition_function,
+    bias_addition_method,
+    bias_addition_module,
+    meta_patched_function,
+    meta_patched_module,
+)
 
 if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
@@ -31,18 +42,6 @@ def _truncate_suffix(s: str):
     return re.sub(r'_\d+$', '', s)
 
 
-def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
-    if isinstance(elements, (tuple, list, set)):
-        for ele in elements:
-            if ele not in list_:
-                return False, ele
-    else:
-        if elements not in list_:
-            return False, elements
-
-    return True, None
-
-
 def default_device():
     return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
 
@@ -52,24 +51,24 @@ class ColoProxy(Proxy):
 
     def __init__(self, *args, data=None, **kwargs):
         super().__init__(*args, **kwargs)
-        self._data = data
+        self._meta_data = data
 
     @property
-    def data(self):
-        return self._data
+    def meta_data(self):
+        return self._meta_data
 
-    @data.setter
-    def data(self, args):
+    @meta_data.setter
+    def meta_data(self, args):
         wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
-        self._data = tree_map(wrap_fn, args)
+        self._meta_data = tree_map(wrap_fn, args)
 
     @classmethod
     def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
         proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
-        unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
+        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
         kwargs = {} if kwargs is None else kwargs
-        if proxy.data is None:
-            proxy.data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+        if proxy.meta_data is None:
+            proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
         return proxy
 
     @classmethod
@@ -77,28 +76,33 @@ class ColoProxy(Proxy):
         return cls(proxy.node, proxy.tracer)
 
     def __repr__(self):
-        return f"ColoProxy({self.node.name}, data={self.data})"
+        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
 
     def __len__(self):
-        return len(self.data)
+        return len(self.meta_data)
 
     def __int__(self):
-        return int(self.data)
+        return int(self.meta_data)
 
     def __index__(self):
         try:
-            return int(self.data)
+            return int(self.meta_data)
         except:
-            return torch.zeros(self.data.shape, dtype=torch.bool).numpy().__index__()
+            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
 
     def __float__(self):
-        return float(self.data)
+        return float(self.meta_data)
 
     def __bool__(self):
-        return self.data
+        return bool(self.meta_data)
 
     def __getattr__(self, k):
-        return ColoAttribute(self, k, getattr(self._data, k, None))
+        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
+
+    def __setitem__(self, key, value):
+        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+        proxy.meta_data = self._meta_data
+        return proxy
 
     def __contains__(self, key):
         if self.node.op == "placeholder":
@@ -109,26 +113,26 @@ class ColoProxy(Proxy):
         return super().__contains__(key)
 
     def __isinstancecheck__(self, type):
-        return isinstance(self.data, type)
+        return isinstance(self.meta_data, type)
 
     @property
     def shape(self):
-        return self.data.shape
+        return self.meta_data.shape
 
     @property
     def ndim(self):
-        return self.data.ndim
+        return self.meta_data.ndim
 
     @property
     def device(self):
         proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
-        proxy.data = self.data.device
+        proxy.meta_data = self.meta_data.device
        return proxy
 
     @property
     def dtype(self):
         proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
-        proxy.data = self.data.dtype
+        proxy.meta_data = self.meta_data.dtype
         return proxy
 
     def to(self, *args, **kwargs):
@@ -148,7 +152,7 @@ class ColoAttribute(ColoProxy):
         self.root = root
         self.attr = attr
         self.tracer = root.tracer
-        self._data = data
+        self._meta_data = data
         self._node: Optional[Node] = None
 
     @property
@@ -172,8 +176,14 @@ class ColoTracer(Tracer):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._disable_module_getattr = False
         self.proxy_buffer_attributes = True
 
+        # whether the tracer will record the usage of torch.utils.checkpoint
+        self.trace_act_ckpt = trace_act_ckpt
+        # whether the current tracing occurs within the activation checkpoint functions
+        self.inside_torch_checkpoint_func = False
+        self.act_ckpt_region_count = 0
+
     def proxy(self, node: Node) -> 'ColoProxy':
         return ColoProxy(node, self)
 
@@ -185,44 +195,52 @@ class ColoTracer(Tracer):
                      name: Optional[str] = None,
                      type_expr: Optional[Any] = None,
                      proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
+
         proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
-        unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
+        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
         if kind == 'placeholder':
-            proxy.data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
+            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
                 _truncate_suffix(target), None)
         elif kind == 'get_attr':
             self._disable_module_getattr = True
             try:
                 attr_itr = self.root
                 atoms = target.split(".")
                 for atom in atoms:
                     attr_itr = getattr(attr_itr, atom)
-                proxy.data = attr_itr
+                proxy.meta_data = attr_itr
             finally:
                 self._disable_module_getattr = False
         elif kind == 'call_function':
-            proxy.data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+            proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
         elif kind == 'call_method':
             self._disable_module_getattr = True
             try:
                 if target == '__call__':
-                    proxy.data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
+                    proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                 else:
                     if target not in _TensorPropertyMethod:
-                        proxy._data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                                          **tree_map(unwrap_fn, kwargs))
+                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
+                                                                               **tree_map(unwrap_fn, kwargs))
             finally:
                 self._disable_module_getattr = False
         elif kind == 'call_module':
             mod = self.root.get_submodule(target)
-            unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
             self._disable_module_getattr = True
             try:
-                proxy.data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+                proxy.meta_data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
             finally:
-                self._disable_module_getattr = True
+                self._disable_module_getattr = False
         return proxy
 
+    def create_node(self, *args, **kwargs) -> Node:
+        node = super().create_node(*args, **kwargs)
+
+        if self.inside_torch_checkpoint_func:
+            # annotate the node as part of the current activation checkpoint region
+            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+        return node
+
     def trace(self,
               root: torch.nn.Module,
              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
@@ -263,11 +281,41 @@ class ColoTracer(Tracer):
         self.concrete_args = concrete_args
         self.meta_args = meta_args
 
-        with _TorchTensorOverride(self):
+        with _TorchTensorOverride(self), self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
             self.graph = super().trace(root, concrete_args=concrete_args)
         self.graph.lint()
         return self.graph
 
+    @contextmanager
+    def trace_activation_checkpoint(self, enabled: bool):
+        if enabled:
+            orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
+
+            class PatchedCheckpointFunction(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, run_function, preserve_rng_state, *args):
+                    # signal that the current tracing occurs within an activation checkpoint region
+                    self.inside_torch_checkpoint_func = True
+                    out = run_function(*args)
+                    self.inside_torch_checkpoint_func = False
+                    self.act_ckpt_region_count += 1
+                    return out
+
+                @staticmethod
+                def backward(ctx: Any, *grad_outputs: Any) -> Any:
+                    raise NotImplementedError(
+                        "We do not implement the backward pass as we only trace the forward pass.")
+
+            # override the checkpoint function
+            torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
+        try:
+            yield
+        finally:
+            if enabled:
+                # restore the original checkpoint function upon exit
+                torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
+
     def _post_check(self, non_concrete_arg_names: Set[str]):
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
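
The `trace_activation_checkpoint` context manager above only takes effect while tracing: every node created while the patched `CheckpointFunction` is running gets tagged with the index of its checkpoint region in `node.meta['activation_checkpoint']`. A minimal sketch of how the annotation can be consumed after tracing (the toy module, shapes, and device placement below are illustrative, not part of the patch):

    import torch
    import torch.utils.checkpoint
    from colossalai.fx.tracer.experimental import ColoTracer

    class TinyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)
            self.linear2 = torch.nn.Linear(4, 4)

        def forward(self, x):
            # traced while the patched CheckpointFunction is active, so the
            # nodes produced here should carry checkpoint region index 0
            x = torch.utils.checkpoint.checkpoint(self.linear1, x)
            return self.linear2(x)

    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(TinyModel(), meta_args={'x': torch.rand(2, 4, device='meta')})
    for node in graph.nodes:
        # nodes outside any checkpointed region print None
        print(node.name, node.meta.get('activation_checkpoint', None))
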
@@ -392,3 +441,202 @@ class _TorchTensorOverride(object):
     def __exit__(self, exc_type, exc_val, exc_tb):
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, orig)
+
+
+def meta_prop_pass(gm: ColoGraphModule,
+                   root: torch.nn.Module,
+                   meta_args: Optional[Dict[str, Any]] = None,
+                   concrete_args: Optional[Dict[str, torch.Tensor]] = None):
+
+    if meta_args is None:
+        meta_args = {}
+
+    if concrete_args is None:
+        concrete_args = {}
+
+    # collect the parameter names of root.forward to validate concrete and meta args
+    sig = inspect.signature(root.forward)
+    sig_names = set(sig.parameters.keys())
+    meta_arg_names = set(meta_args.keys())
+
+    # update concrete args with default values
+    non_meta_arg_names = sig_names - meta_arg_names
+    for k, v in sig.parameters.items():
+        if k in non_meta_arg_names and \
+                k not in concrete_args and \
+                v.default is not inspect.Parameter.empty:
+            concrete_args[k] = v.default
+
+    for node in gm.graph.nodes:
+        node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
+                                               node.kwargs)
+
+def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
+    unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
+    if kind == 'placeholder':
+        meta_out = meta_args[target] if target in meta_args else concrete_args.get(
+            _truncate_suffix(target), None)
+    elif kind == 'get_attr':
+        attr_itr = root
+        atoms = target.split(".")
+        for atom in atoms:
+            attr_itr = getattr(attr_itr, atom)
+        meta_out = attr_itr
+    elif kind == 'call_function':
+        meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+    elif kind == 'call_method':
+        if target == '__call__':
+            meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
+        else:
+            if target not in _TensorPropertyMethod:
+                meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
+                                                               **tree_map(unwrap_fn, kwargs))
+            else:
+                meta_out = None
+    elif kind == 'call_module':
+        mod = root.get_submodule(target)
+        meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+    else:
+        meta_out = None
+    return meta_out
+
+def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
+    if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
+        meta_out = meta_args[target]
+        return meta_out
+
+    if target in [getattr(torch, torch_func) for torch_func in _TorchNewMethod]:
+        # NOTE: tensor constructors in PyTorch define the `device` argument as
+        # *kwargs-only*. That is why this works. If you add methods to
+        # _TorchNewMethod that do not define `device` as kwarg-only,
+        # this will break and you will likely see issues where we cannot infer
+        # the size of the output.
+        if "device" in kwargs:
+            kwargs["device"] = "meta"
+
+    try:
+        unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
+        args_metas = tree_map(unwrap_fn, args)
+        kwargs_metas = tree_map(unwrap_fn, kwargs)
+
+        if kind == "call_function":
+            # fetch patched function
+            if meta_patched_function.has(target):
+                meta_target = meta_patched_function.get(target)
+            elif meta_patched_function.has(target.__name__):
+                # use name for some builtin op like @ (matmul)
+                meta_target = meta_patched_function.get(target.__name__)
+            else:
+                meta_target = target
+
+            meta_out = meta_target(*args_metas, **kwargs_metas)
+
+            if isinstance(meta_out, torch.Tensor):
+                meta_out = meta_out.to(device="meta")
+        elif kind == "call_method":
+            method = getattr(args_metas[0].__class__, target)
+
+            # fetch patched method
+            if meta_patched_function.has(method):
+                meta_target = meta_patched_function.get(method)
+            else:
+                meta_target = method
+
+            meta_out = meta_target(*args_metas, **kwargs_metas)
+        elif kind == "call_module":
+            mod = root.get_submodule(target)
+            mod_type = type(mod)
+            if meta_patched_module.has(mod_type):
+                meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
+            else:
+                meta_out = mod(*args_metas, **kwargs_metas)
+        elif kind == "get_attr":
+            attr_itr = root
+            atoms = target.split(".")
+            for atom in atoms:
+                attr_itr = getattr(attr_itr, atom)
+            if isinstance(attr_itr, torch.nn.parameter.Parameter):
+                meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
+            elif isinstance(attr_itr, torch.Tensor):
+                meta_out = attr_itr.to(device="meta")
+            else:
+                meta_out = attr_itr
+        else:
+            return None
+
+    except Exception as e:
+        raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
+
+    return meta_out
+
+
+def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
+    result_graph = Graph()
+    value_remap = {}
+    unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
+
+    for orig_node in gm.graph.nodes:
+        assert hasattr(orig_node, "_meta_data")
+        kind = orig_node.op
+        target = orig_node.target
+        args = orig_node.args
+        kwargs = orig_node.kwargs
+
+        args_metas = tree_map(unwrap_fn, args)
+        tracer = ColoTracer()
+        tracer.graph = Graph(tracer_cls=ColoTracer)
+        tracer.root = root_model
+
+        def wrap_fn(n):
+            if isinstance(n, Node):
+                proxy = ColoProxy(n, tracer)
+                proxy.meta_data = n._meta_data
+                return proxy
+            return n
+
+        args_proxy = tree_map(wrap_fn, args)
+        kwargs_proxy = tree_map(wrap_fn, kwargs)
+
+        handle = None
+        if kind == "call_function":
+            if bias_addition_function.has(target):
+                if target == torch.nn.functional.linear:
+                    if 'bias' in kwargs and kwargs['bias'] is not None:
+                        function_to_substitute = func_to_func_dict[target]
+                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                else:
+                    function_to_substitute = func_to_func_dict[target]
+                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+            elif bias_addition_function.has(target.__name__):
+                # use name for some builtin op like @ (matmul)
+                function_to_substitute = func_to_func_dict[target]
+                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+
+        elif kind == "call_method":
+            method = getattr(args_metas[0].__class__, target)
+            if bias_addition_method.has(method):
+                function_to_substitute = method_to_func_dict[method]
+                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+
+        elif kind == "call_module":
+            mod = gm.get_submodule(target)
+            mod_type = type(mod)
+            if bias_addition_module.has(mod_type) and mod.bias is not None:
+                function_to_substitute = module_to_func_dict[mod_type]
+                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+
+        if handle is not None:
+            handle.generate()
+            for node_inserted in tracer.graph.nodes:
+                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
+                last_node = value_remap[node_inserted]
+            value_remap[orig_node] = last_node
+        else:
+            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
+
+        del tracer
+
+    gm.graph = result_graph
+    gm.recompile()
+
+    meta_prop_pass(gm, root_model, meta_args)
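
Together, the two new passes form a small post-tracing pipeline: `meta_prop_pass` attaches `_meta_data` to every node (which `bias_addition_pass` asserts on), and `bias_addition_pass` then rewrites bias-carrying calls into explicit multiply/add nodes, re-running meta propagation on the rewritten graph as its final step. A minimal driver sketch under those assumptions (the toy model, shapes, and the meta-device placement are illustrative, not part of the patch):

    import torch
    from colossalai.fx import ColoGraphModule
    from colossalai.fx.tracer.experimental import ColoTracer, bias_addition_pass, meta_prop_pass

    class MLP(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 8)

        def forward(self, x):
            return self.linear(x)

    # keep parameters on the meta device so meta propagation never touches real storage
    model = MLP().to('meta')
    meta_args = {'x': torch.rand(2, 4, device='meta')}

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)

    meta_prop_pass(gm, model, meta_args)      # every node now carries _meta_data
    bias_addition_pass(gm, model, meta_args)  # bias-carrying linear -> matmul + add, then re-propagate
    print(gm.graph)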