From cd5cf2bcc90c585912e917a944cfcaba17c6b45c Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Thu, 15 Sep 2022 14:46:36 +0800
Subject: [PATCH] [fx/tuning] tune performance on rotor with meta info. (#1599)

---
 .../fx/passes/algorithms/ckpt_solver_rotor.py | 83 ++++---------------
 colossalai/fx/passes/algorithms/linearize.py  |  6 +-
 colossalai/fx/profiler/__init__.py            |  2 +-
 colossalai/fx/profiler/dataflow.py            | 28 ++++---
 colossalai/fx/profiler/memory.py              | 36 +++++++-
 colossalai/fx/profiler/opcount.py             |  1 +
 colossalai/fx/profiler/profiler.py            | 47 ++++++-----
 7 files changed, 96 insertions(+), 107 deletions(-)

diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index 9cb48828e..d0928b405 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,8 +1,7 @@
 from typing import List, Tuple
-import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size
 import math
 from .linearize import linearize
 from .utils import *
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
     what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
-    ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
+    # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

     # Initialize borders of the tables for lmax-lmin = 0
     for m in range(mmax + 1):
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]


-def _compute_size(obj: torch.Tensor) -> int:
-    return obj.numel() * obj.element_size()
-
-
-def _compute_output_size(node: List[Node]) -> int:
-    """Compute the output size of a node
-
-    Args:
-        node (List[Node]): node, list of torch.fx.Node
-
-    Returns:
-        int: output size
-    """
-
-    return node[-1].meta['tensor_meta'].numel * torch.tensor([],
-                                                             dtype=node[-1].meta['tensor_meta'].dtype).element_size()
-
-
-def _get_inplace(node: Node) -> bool:
-    """Get the inplace argument from torch.fx.Node
-
-    Args:
-        node (Node): torch.fx.Node
-
-    Returns:
-        bool: indicates whether this op is inplace
-    """
-
-    is_inplace = False
-    if node.op == "call_function":
-        is_inplace = node.kwargs.get("inplace", False)
-    elif node.op == "call_module":
-        is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
-
-    return is_inplace
-
-
 def _fwd_xbar(node: List[Node]) -> int:
     """Get the forward xbar of a node

@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         for k, v in deps.items():
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
+            if v == float('-inf'):
+                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']

         return deps_size

     bwd_mem_tmp = 0
     deps = {}

-    # add all the users for last node into deps,
-    # as those nodes' gradient out will be stored in memory
-    for child in node[-1].users:
-        deps[child] = 1
     for n in reversed(node):
+        deps[n] = len(n.all_input_nodes)
         bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])

-        deps[n] = len(n.all_input_nodes)
         for child in n.users:
             if child in deps:
                 deps[child] -= 1
-
-        for key in list(deps.keys()):
-            if deps[key] == 0:
-                del deps[key]
+                if deps[child] <= 0:
+                    deps[child] = float('-inf')    # free

     return bwd_mem_tmp


-def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:
     fwd_time = []
     bwd_time = []
-
-    if isinstance(data, torch.Tensor):
-        xbar_sizes = [_compute_size(data)]
-        x_sizes = [_compute_size(data)]
-    elif isinstance(data, list) or isinstance(data, tuple):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data])]
-        x_sizes = [sum([_compute_size(obj) for obj in data])]
-    elif isinstance(data, dict):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
-        x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
-
+    xbar_sizes = [activation_size(input)]
+    x_sizes = [activation_size(input)]
     # currently we can't get the temp memory needed in fwd
     tmp_fwd = [0] * len(node_list)
     tmp_bwd = []

@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(_compute_output_size(node))
+        x_sizes.append(node[-1].meta['fwd_mem_out'])
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
         tmp_bwd.append(_get_bwd_mem_tmp(node))

-        # if a node with only one inplace op, we need to let x_bar = 0
-        if len(node) == 1 and _get_inplace(node[0]):
-            xbar_sizes[-1] = 0
-
     bwd_time.append(0)

     # currently we view loss backward temp as zero
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
                  mem_limit: int,
                  mem_slots: int = 500,
                  cnode: List[str] = None,
-                 eps: float = 0.02) -> ColoGraphModule:
+                 eps: float = 0.0) -> ColoGraphModule:
     """solver that automatically find activation checkpoint in rotor's manner

     Args:
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
         mem_limit (int): memory budget in Byte.
         mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
         cnode (List[Node], optional): common node list for linearize. Defaults to None.
-        eps (float): epsilon for memory decay. Defaults to 0.02
+        eps (float): epsilon for memory decay. Defaults to 0.0

     Returns:
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py
index 043827a76..1a49364f5 100644
--- a/colossalai/fx/passes/algorithms/linearize.py
+++ b/colossalai/fx/passes/algorithms/linearize.py
@@ -1,5 +1,6 @@
 from typing import List, Any
 from torch.fx import GraphModule, Node
+from colossalai.fx.profiler import is_inplace

 # Common nodes are type of nodes that could be seen as attributes and remain
 # unchanged throughout the whole model, it will be used several times by
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     Returns:
         List[List[Node]]: List of list, each inside list of Node presents
        the actual 'node' in linearized manner.
+
+    Remarks:
+        We merge the inplace ops into the previous node.
""" def _is_sink() -> bool: @@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]: bool """ - return not sum([v for _, v in deps.items()]) + return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) # make sure that item in cnode is valid if cnode: diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index fb19618b2..be37fea70 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -7,4 +7,4 @@ else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module from .dataflow import GraphInfo -from .memory import parameter_size, activation_size +from .memory import parameter_size, activation_size, is_inplace diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index 69319b792..2b4b6c17e 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,16 +1,17 @@ from dataclasses import dataclass from enum import Enum -from functools import partial from typing import Dict from torch.fx import Graph, Node -from .memory import activation_size +from .memory import activation_size, is_inplace +from . import META_COMPATIBILITY +if META_COMPATIBILITY: + from .memory import NORMALIZATION_ATEN, CLONE_ATEN class Phase(Enum): FORWARD = 0 - LOSS = 1 - BACKWARD = 2 - PLACEHOLDER = 3 + BACKWARD = 1 + PLACEHOLDER = 2 @dataclass @@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: def _peak_memory(deps: Dict[Node, int]): peak_mem = 0 for k, v in deps.items(): - if v > 0: + if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)): peak_mem += activation_size(k.meta['out']) + if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN): + peak_mem -= activation_size(k.meta['out']) return peak_mem # deps is used to track all the memory dependencies of the graph. @@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: for n in graph.nodes: n: Node - if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)): + if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)): # A forward tensor who is marked `save` but is not # an input to `loss` should be saved during forward. # If the tensor is a placeholder, then it belongs to `fwd_mem_in`. @@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: graph_info.fwd_mem_tmp += activation_size(n.meta['out']) elif is_phase(n, Phase.BACKWARD): if len(n.users): - # liveness analysis is only used in backward - deps[n] = len(n.users) graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) - for input_n in n.all_input_nodes: - if input_n in deps: - deps[input_n] -= 1 else: + # TODO: some of the bwd_mem_out might be model parameters. 
                 # basically a backward node without user is a `grad_out` node
                 graph_info.bwd_mem_out += activation_size(n.meta['out'])
+        for input_n in n.all_input_nodes:
+            if input_n in deps:
+                deps[input_n] -= 1
+                if deps[input_n] <= 0:
+                    deps[input_n] = float('-inf')
     return graph_info
diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py
index c023d0d1e..2e0f1a058 100644
--- a/colossalai/fx/profiler/memory.py
+++ b/colossalai/fx/profiler/memory.py
@@ -1,9 +1,10 @@
 import torch
+from torch.fx import Node
 from typing import Union, Dict, List, Tuple
 from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY

-__all__ = ['activation_size', 'parameter_size']
+__all__ = ['activation_size', 'parameter_size', 'is_inplace']

 if META_COMPATIBILITY:
     aten = torch.ops.aten
@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
         aten.bernoulli_.float,

         # inplace reshaping
+        aten.copy_.default,
         aten.detach.default,
         aten.t.default,
         aten.transpose.int,
@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
         aten._unsafe_view.default,
     ]

-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
+    NORMALIZATION_ATEN = [
+        aten.native_batch_norm.default,
+        aten.native_layer_norm.default,
+        # aten.max_pool2d_with_indices.default,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']

 else:
     # TODO fill out the inplace ops
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
     for param in mod.parameters():
         param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
     return param_size
+
+
+def is_inplace(n: Node):
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        node (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+    inplace = False
+    if n.op == "call_function":
+        inplace = n.kwargs.get("inplace", False)
+        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
+            inplace = True
+    elif n.op == "call_module":
+        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+
+    return inplace
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index 3489f00be..4d51e0eea 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -222,6 +222,7 @@ flop_mapping = {
     aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
 }

 elementwise_flop_aten = [
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index a152385e8..8051a753c 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -1,12 +1,10 @@
-from dataclasses import dataclass
-from enum import auto
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
 from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS, activation_size
+from .memory import WEIRD_OPS
 from .tensor import MetaTensor
 from .opcount import flop_mapping
@@ -23,7 +21,7 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
     """
     Profile a Callable function with args and kwargs.

@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
         Phase.FORWARD: 0,
-        Phase.LOSS: 0,
         Phase.BACKWARD: 0,
     }
@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
             kwargs_node = tree_map(get_node, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)

+            # do not allocate on `cpu`
+            if 'device' in kwargs:
+                kwargs['device'] = 'meta'
+
             def unwrap(x):
                 # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
                 if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     if target not in WEIRD_OPS:

         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
     else:

         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x

     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)
@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     tree_map(set_placeholder, kwargs)

     def pack(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
             x._node.meta['saved'] = True
         return x
@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     else:
         out = target(*args, **kwargs)

-    # If the output is not a floating point `torch.Tensor` or it does not
-    # requires grad, then we should not run backward for this node.
-    if is_autogradable(out) and out.requires_grad:
-        phase = Phase.LOSS
-        loss = out.sum()
-        phase = Phase.BACKWARD
-        loss.backward()
+        # If the output is not a floating point `torch.Tensor` or it does not
+        # requires grad, then we should not run backward for this node.
+        if is_autogradable(out) and out.requires_grad:
+            phase = Phase.BACKWARD
+            if isinstance(out, FlopTensor):
+                out._node.meta['save'] = False
+            grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
+                out, device='meta')
+            torch.autograd.backward(out, FlopTensor(grad))

     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
-        >>> output, meta_info = profile_function(func)(input, inplace=False)
+        >>> output, meta_info = profile_function(func)(input)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
             args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
             kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)

         out, meta = _profile(func, *args, **kwargs)
         return out, meta
@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, inplace=False, **kwargs)
+        out, meta = _profile(target, *args, **kwargs)
         return out, meta

     return f
@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
             args = tree_map(lambda x: x.to('meta'), args)
             kwargs = tree_map(lambda x: x.to('meta'), kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
-        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        out, meta = _profile(func, *args, **kwargs)
         return out, meta

     f.__name__ = module.__class__.__name__
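
Usage sketch (editor's note, not part of the patch): the rotor solver above reads the meta information that the fx profiler writes into node.meta, so a model has to be traced and profiled before solver_rotor can annotate it. The snippet below is a minimal sketch of that flow; the import paths, the MetaInfoProp pass, and the `data` argument sitting between `gm` and `mem_limit` are assumptions inferred from the signatures in this patch, not verified APIs.

    import torch
    import torchvision.models as tm

    from colossalai.fx import ColoTracer, ColoGraphModule            # assumed import paths
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp      # assumed pass that fills node.meta
    from colossalai.fx.passes.algorithms import solver_rotor          # assumed re-export of the solver above

    model = tm.resnet18()
    data = torch.rand(8, 3, 224, 224, device='meta')

    # Trace into a ColoGraphModule so every Node carries a meta dict for the profiler.
    graph = ColoTracer().trace(model, meta_args={'x': data})
    gm = ColoGraphModule(model, graph, model.__class__.__name__)

    # Populate the fwd/bwd time and memory entries (fwd_mem_out, bwd_mem_tmp, ...)
    # that _construct_chain and _get_bwd_mem_tmp read from node.meta.
    MetaInfoProp(gm).run(data)

    # Ask the solver for a checkpoint schedule under a 2 GiB activation budget.
    gm = solver_rotor(gm, data, mem_limit=2 * 1024**3, mem_slots=500)

If the solve succeeds, the returned module carries the __sequence__ attribute described in the solver_rotor docstring.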