From 132b4306b7b38a9c488c6dd43ade41475951498e Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Tue, 4 Oct 2022 16:48:24 +0800 Subject: [PATCH] [fx] Add concrete info prop (#1677) * [fx] concreteinfoprop * [fx] add concreteinfoprop * [fx] modify docstring of ConcreteInfoProp * [fx] fix device error * [fx] modify parameter calculation * [fx] modify parameters calculation --- colossalai/fx/passes/__init__.py | 1 + colossalai/fx/passes/concrete_info_prop.py | 290 +++++++++++++++++++++ colossalai/fx/profiler/dataflow.py | 7 +- colossalai/fx/profiler/profiler.py | 154 +++++++++-- 4 files changed, 424 insertions(+), 28 deletions(-) create mode 100644 colossalai/fx/passes/concrete_info_prop.py diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py index aa6a7009c..43ac14ec4 100644 --- a/colossalai/fx/passes/__init__.py +++ b/colossalai/fx/passes/__init__.py @@ -1,3 +1,4 @@ from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass from .meta_info_prop import MetaInfoProp +from .concrete_info_prop import ConcreteInfoProp diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py new file mode 100644 index 000000000..44dea6fc4 --- /dev/null +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -0,0 +1,290 @@ +from dataclasses import asdict +from colossalai.fx.profiler import GraphInfo +import torch +import torch.fx +from torch.fx.node import Node, Argument, Target +from torch.utils._pytree import tree_flatten +from typing import Any, List, Tuple, NamedTuple, Dict, Optional +from torch.fx._compatibility import compatibility +from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size +from torch.fx.graph_module import GraphModule + + +@compatibility(is_backward_compatible=True) +class ConcreteInfoProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node with concrete tensor and record the memory + usage, execution time of forward and backward, and type of the result into + the corresponding node. + + Usage: + BATCH_SIZE = 2 + DIM_IN = 4 + DIM_HIDDEN = 16 + DIM_OUT = 16 + model = torch.nn.Sequential( + torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_HIDDEN, DIM_OUT), + ).cuda() + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda") + gm = symbolic_trace(model) + interp = ConcreteInfoProp(gm) + interp.run(input_sample) + print(interp.summary(unit='kb')) + + + output of above code is + Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP + ----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- --------- + placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB + call_module _0 0.0003993511199951172 s 0.00706791877746582 s False 0.50 KB 0.00 KB 0.03 KB 0.66 KB + call_module _1 6.29425048828125e-05 s 0.00018286705017089844 s False 0.50 KB 0.00 KB 0.12 KB 0.81 KB + output output 0.0 s 0.0 s True 0.00 KB 0.00 KB 0.00 KB 0.00 KB + Args: + module (GraphModule): The module to be executed + + """ + + _is_proped: bool = False + + def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any: + """Customized run for ConcreteInfoProp + We need to store the device in self.device + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + + flatten_args, _ = tree_flatten(args) + self.device = next(item for item in flatten_args if hasattr(item, "device")).device + return super().run(*args, initial_env, enable_io_processing) + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + self._is_proped = True + result, meta_info = super().run_node(n) + + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + # TODO: the attribute node_size should be removed in the future + setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) + n.meta['type'] = type(result) + + # retain the autograd graph + for param in self.module.parameters(): + param.grad = None + + return result + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return super().placeholder(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. + """ + return super().get_attr(target, args, kwargs), GraphInfo() + + @compatibility(is_backward_compatible=True) + def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + assert not isinstance(target, str) + return profile_function(target, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return profile_method(target, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node with meta tensor and return the result and its meta profile. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + # Retrieve executed args and kwargs values from the environment + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + return profile_module(submod, self.device)(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + result (Any): The argument value that was retrieved + meta_info (MetaInfo): The memory cost and forward & backward time. + """ + return args[0], GraphInfo(save_fwd_in=True) + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + return super().run(*args) + + def summary(self, unit: str = 'MB') -> str: + """ + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module + to be installed. + """ + # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py + try: + from tabulate import tabulate + except ImportError: + print("`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + + assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." + + # Build up a list of summary information for each node + node_summaries: List[List[Any]] = [] + + def mem_repr(mem: int) -> str: + unit_divisor_map = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + } + return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" + + def time_repr(time: float): + return f"{time:,} s" + + for node in self.module.graph.nodes: + node: Node + node_summaries.append([ + node.op, + str(node), + time_repr(node.meta['fwd_time']), + time_repr(node.meta['bwd_time']), + node.meta['save_fwd_in'], + mem_repr(node.meta['fwd_mem_out']), + mem_repr(node.meta['fwd_mem_tmp']), + mem_repr(node.meta['bwd_mem_out']), + mem_repr(node.meta['bwd_mem_tmp']), + ]) + + # Use the ``tabulate`` library to create a well-formatted table + # presenting our summary information + headers: List[str] = [ + 'Op type', + 'Op', + 'Forward time', + 'Backward time', + 'SAVE_FWD_IN', + 'FWD_OUT', + 'FWD_TMP', + 'BWD_OUT', + 'BWD_TMP', + ] + + return tabulate(node_summaries, headers=headers, stralign='right') diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 0551f6e25..14d876a78 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from enum import Enum +from functools import partial from typing import Dict from torch.fx import Graph, Node from .memory import activation_size, is_inplace @@ -33,8 +34,10 @@ class GraphInfo: ------------------------------- ============================================================================ Attributes: - fwd_flop (int): The forward FLOPs of a certain node + fwd_flop (int): The forward FLOPs of a certain node. + fwd_time (float): The real forward time (s) of a certain node. bwd_flop (int): The backward FLOPs of a certain node. + bwd_time (float): The real backward time (s) of a certain node. save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes. fwd_mem_tmp (int): See the above illustration. fwd_mem_out (int): See the above illustration. @@ -42,7 +45,9 @@ class GraphInfo: bwd_mem_out (int): See the above illustration. """ fwd_flop: int = 0 + fwd_time: float = 0.0 bwd_flop: int = 0 + bwd_time: float = 0.0 save_fwd_in: bool = False fwd_mem_tmp: int = 0 fwd_mem_out: int = 0 diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 563a234d9..4b2874fdb 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -5,10 +5,11 @@ from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo -from .memory import activation_size +from .memory import activation_size, parameter_size from .constant import ALIAS_ATEN from .tensor import MetaTensor from .opcount import flop_mapping +import time __all__ = ['profile_function', 'profile_module', 'profile_method'] @@ -27,33 +28,112 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() -# super-dainiu: -# x.detach() will change the unique identifier of data_ptr -# we need to handle this in a stupid way -def detach(x): +def detach_variables(x): if isinstance(x, torch.Tensor): requires_grad = x.requires_grad - x.requires_grad_(False) - x.requires_grad_(requires_grad) + x = x.detach() + x.requires_grad = requires_grad + + return x def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: - """ - Profile a Callable function with args and kwargs on concrete devices. + """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30 + To profile the actual forward memory, we first run target in the context torch.no_grad() to get + the fwd_mem_out, then we run target with grad enable to found the extra memory stored in the memory + by memory allocated minus the fwd_mem_out. + To profile the actual backward memory, we first make dummy gradient for torch.autograd.backward, then + find the bwd_mem_tmp with memory peak during the process minus bwd_mem_out(it is actually equal to size + of args and kwargs). + We also add time stamps to profile the real forward and backward time. Args: target (Callable): A Callable function - args (Any): Argument - kwargs (Any): Argument - - Raises: - NotImplementedError: TODO(yby) + args (Any): Arguments + kwargs (Any): Arguments Returns: - out (Tuple[Any, ...]): The argument value that was retrieved. - meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`. + Tuple[Tuple[Any, ...], GraphInfo]: Output for next node & memory cost and real forward and backward + time. """ - raise NotImplementedError + + graphinfo = GraphInfo() + + # detach input from the graph + args = tree_map(detach_variables, args) + kwargs = tree_map(detach_variables, kwargs) + if isinstance(target, str): + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # calculate fwd_mem_out + mem_stamp0 = torch.cuda.memory_allocated() + with torch.no_grad(): + out = getattr(self_obj, target)(*args_tail, **kwargs) + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 + del out + + # calculate fwd_mem_tmp & fwd_time + mem_stamp0 = torch.cuda.memory_allocated() + fwd_time0 = time.time() + out = getattr(self_obj, target)(*args_tail, **kwargs) + fwd_time1 = time.time() + graphinfo.fwd_time = fwd_time1 - fwd_time0 + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out + + # calculate bwd_mem_tmp & bwd_time + grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + bwd_time0 = time.time() + torch.autograd.backward(out, grad_tensors=grad_tensors) + bwd_time1 = time.time() + graphinfo.bwd_time = bwd_time1 - bwd_time0 + mem_stamp1 = torch.cuda.max_memory_allocated() + + # calculate bwd memory stats + # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation + graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) + graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 + graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out + + else: + # calculate fwd_mem_out + mem_stamp0 = torch.cuda.memory_allocated() + with torch.no_grad(): + out = target(*args, **kwargs) + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 + del out + + # calculate fwd_mem_tmp & fwd_time + mem_stamp0 = torch.cuda.memory_allocated() + fwd_time0 = time.time() + out = target(*args, **kwargs) + fwd_time1 = time.time() + graphinfo.fwd_time = fwd_time1 - fwd_time0 + mem_stamp1 = torch.cuda.memory_allocated() + graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out + + # calculate bwd_mem_tmp & bwd_time + grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) + torch.cuda.reset_peak_memory_stats() + mem_stamp0 = torch.cuda.memory_allocated() + bwd_time0 = time.time() + torch.autograd.backward(out, grad_tensors=grad_tensors) + bwd_time1 = time.time() + graphinfo.bwd_time = bwd_time1 - bwd_time0 + mem_stamp1 = torch.cuda.max_memory_allocated() + + # calculate bwd memory stats + # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation + graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) + graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 + graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out + + return tree_map(detach_variables, out), graphinfo def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: @@ -135,7 +215,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G name=subgraph._graph_namespace.create_name('input', x._tensor)) x._node.meta['phase'] = Phase.PLACEHOLDER x._node.meta['saved_tensor'] = [] - detach(x) return x # Basically, we need to detach the args and kwargs from the outer graph. @@ -206,12 +285,26 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: kwargs['inplace'] = False if device == 'meta': out, meta = _profile_meta(func, *args, **kwargs) - else: - out, meta = _profile_concrete(func, *args, **kwargs) - if inplace: + + # currently we set the fwd_mem_tmp of ReLU to zero if target in [torch.nn.functional.relu]: meta.save_fwd_in = False meta.bwd_mem_out = 0 + meta.fwd_mem_tmp = 0 + else: + out, meta = _profile_concrete(func, *args, **kwargs) + + # find the grad for parameter in args and kwargs + param_size = 0 + + def get_param_size(x): + if isinstance(x, torch.nn.parameter): + param_size += activation_size(x) + + tree_map(get_param_size, args) + tree_map(get_param_size, kwargs) + + meta.bwd_mem_out -= param_size return out, meta f.__name__ = target.__name__ @@ -257,18 +350,25 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: # If there is an argument that this `call_module` is inplace, we should # still run the profiling but discard some results regarding `module`. inplace = getattr(module, 'inplace', False) + + # calculate parameter size + param_size = parameter_size(module) + if inplace: module.inplace = False if device == 'meta': out, meta = _profile_meta(func, *args, **kwargs) - else: - out, meta = _profile_concrete(func, *args, **kwargs) - if inplace: - # super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU` - # is the only inplace activation function that discard its input. - if type(module) in [torch.nn.ReLU]: + + # currently we set the fwd_mem_tmp of ReLU to zero + if type(module) in [torch.nn.modules.activation.ReLU]: meta.save_fwd_in = False meta.bwd_mem_out = 0 + meta.fwd_mem_tmp = 0 + else: + out, meta = _profile_concrete(func, *args, **kwargs) + + # grad for param will not be counted + meta.bwd_mem_out -= param_size return out, meta f.__name__ = module.__class__.__name__