[fx/tuning] tune performance on rotor with meta info. (#1599)

Super Daniel committed via GitHub
parent a7cda6f57d
commit cd5cf2bcc9

@@ -1,8 +1,7 @@
 from typing import List, Tuple
-import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size
 import math
 from .linearize import linearize
 from .utils import *
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
     what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
-    ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
+    # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

     # Initialize borders of the tables for lmax-lmin = 0
     for m in range(mmax + 1):
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]


-def _compute_size(obj: torch.Tensor) -> int:
-    return obj.numel() * obj.element_size()
-
-
-def _compute_output_size(node: List[Node]) -> int:
-    """Compute the output size of a node
-
-    Args:
-        node (List[Node]): node, list of torch.fx.Node
-
-    Returns:
-        int: output size
-    """
-    return node[-1].meta['tensor_meta'].numel * torch.tensor([],
-                                                             dtype=node[-1].meta['tensor_meta'].dtype).element_size()
-
-
-def _get_inplace(node: Node) -> bool:
-    """Get the inplace argument from torch.fx.Node
-
-    Args:
-        node (Node): torch.fx.Node
-
-    Returns:
-        bool: indicates whether this op is inplace
-    """
-    is_inplace = False
-    if node.op == "call_function":
-        is_inplace = node.kwargs.get("inplace", False)
-    elif node.op == "call_module":
-        is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
-    return is_inplace
-
-
 def _fwd_xbar(node: List[Node]) -> int:
     """Get the forward xbar of a node
@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         for k, v in deps.items():
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
+            if v == float('-inf'):
+                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
         return deps_size

     bwd_mem_tmp = 0
     deps = {}

-    # add all the users for last node into deps,
-    # as those nodes' gradient out will be stored in memory
-    for child in node[-1].users:
-        deps[child] = 1
     for n in reversed(node):
-        deps[n] = len(n.all_input_nodes)
         bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
+        deps[n] = len(n.all_input_nodes)
         for child in n.users:
             if child in deps:
                 deps[child] -= 1
-                for key in list(deps.keys()):
-                    if deps[key] == 0:
-                        del deps[key]
+                if deps[child] <= 0:
+                    deps[child] = float('-inf')    # free

     return bwd_mem_tmp


-def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:

     fwd_time = []
     bwd_time = []
-
-    if isinstance(data, torch.Tensor):
-        xbar_sizes = [_compute_size(data)]
-        x_sizes = [_compute_size(data)]
-    elif isinstance(data, list) or isinstance(data, tuple):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data])]
-        x_sizes = [sum([_compute_size(obj) for obj in data])]
-    elif isinstance(data, dict):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
-        x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
+    xbar_sizes = [activation_size(input)]
+    x_sizes = [activation_size(input)]

     # currently we can't get the temp memory needed in fwd
     tmp_fwd = [0] * len(node_list)
     tmp_bwd = []
@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(_compute_output_size(node))
+        x_sizes.append(node[-1].meta['fwd_mem_out'])
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
         tmp_bwd.append(_get_bwd_mem_tmp(node))
-        # if a node with only one inplace op, we need to let x_bar = 0
-        if len(node) == 1 and _get_inplace(node[0]):
-            xbar_sizes[-1] = 0

     bwd_time.append(0)

     # currently we view loss backward temp as zero
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
                  mem_limit: int,
                  mem_slots: int = 500,
                  cnode: List[str] = None,
-                 eps: float = 0.02) -> ColoGraphModule:
+                 eps: float = 0.0) -> ColoGraphModule:
     """solver that automatically find activation checkpoint in rotor's manner

     Args:
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
         mem_limit (int): memory budget in Byte.
         mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
         cnode (List[Node], optional): common node list for linearize. Defaults to None.
-        eps (float): epsilon for memory decay. Defaults to 0.02
+        eps (float): epsilon for memory decay. Defaults to 0.0

     Returns:
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
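
The reworked `_get_bwd_mem_tmp` no longer deletes freed entries from `deps`: a freed node is marked with `float('-inf')`, so `_get_deps_size` can keep charging live gradients (`bwd_mem_out`) while subtracting the forward footprint (`fwd_mem_tmp + fwd_mem_out`) of nodes that are already freed. A minimal, self-contained sketch of this tombstone scheme (the `FakeNode` class and its `meta` numbers are invented for illustration):

# Reference counting with '-inf' tombstones, as in the new _get_bwd_mem_tmp.
class FakeNode:
    def __init__(self, bwd_mem_out=0, fwd_mem_tmp=0, fwd_mem_out=0):
        self.meta = {'bwd_mem_out': bwd_mem_out, 'fwd_mem_tmp': fwd_mem_tmp, 'fwd_mem_out': fwd_mem_out}

def deps_size(deps):
    size = 0
    for k, v in deps.items():
        if v > 0:    # still pending: its gradient output stays resident
            size += k.meta['bwd_mem_out']
        if v == float('-inf'):    # freed: its forward memory is reclaimed
            size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
    return size

a = FakeNode(bwd_mem_out=4, fwd_mem_tmp=2, fwd_mem_out=1)
deps = {a: 1}
print(deps_size(deps))    # 4: `a` is alive, its grad output counts
deps[a] -= 1
if deps[a] <= 0:
    deps[a] = float('-inf')    # free, but keep the tombstone
print(deps_size(deps))    # -3: the freed forward memory is subtracted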

@@ -1,5 +1,6 @@
 from typing import List, Any
 from torch.fx import GraphModule, Node
+from colossalai.fx.profiler import is_inplace

 # Common nodes are type of nodes that could be seen as attributes and remain
 # unchanged throughout the whole model, it will be used several times by
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     Returns:
         List[List[Node]]: List of list, each inside list of Node presents
         the actual 'node' in linearized manner.
+
+    Remarks:
+        We merge the inplace ops into the previous node.
     """

     def _is_sink() -> bool:
@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
             bool
         """
-        return not sum([v for _, v in deps.items()])
+        return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))

     # make sure that item in cnode is valid
     if cnode:
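
With the extra `not any(map(is_inplace, n.users))` condition, `_is_sink` refuses to end a linearized block while the current node still has an inplace user, which is exactly how inplace ops get merged into the previous node (per the new `Remarks` above). A toy rendering of that rule, with an invented `is_inplace` predicate and node names:

def is_sink(deps, users, is_inplace):
    # a valid cut point: nothing pending, and no inplace op hanging off us
    return not sum(deps.values()) and not any(map(is_inplace, users))

inplace_ops = {'relu_'}
deps = {'conv': 0}    # all tracked dependencies resolved
print(is_sink(deps, ['relu_'], inplace_ops.__contains__))     # False: keep relu_ in this block
print(is_sink(deps, ['linear'], inplace_ops.__contains__))    # True: safe to cut here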

@@ -7,4 +7,4 @@
     from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module

 from .dataflow import GraphInfo
-from .memory import parameter_size, activation_size
+from .memory import parameter_size, activation_size, is_inplace
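
With this re-export, callers (like `linearize` in this same PR) can pull the helper from the package root:

from colossalai.fx.profiler import activation_size, is_inplace, parameter_size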

@@ -1,16 +1,17 @@
 from dataclasses import dataclass
 from enum import Enum
-from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
-from .memory import activation_size
+from .memory import activation_size, is_inplace
+from . import META_COMPATIBILITY
+
+if META_COMPATIBILITY:
+    from .memory import NORMALIZATION_ATEN, CLONE_ATEN


 class Phase(Enum):
     FORWARD = 0
-    LOSS = 1
-    BACKWARD = 2
-    PLACEHOLDER = 3
+    BACKWARD = 1
+    PLACEHOLDER = 2


 @dataclass
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0:
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
                 peak_mem += activation_size(k.meta['out'])
+            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
+                peak_mem -= activation_size(k.meta['out'])
         return peak_mem

     # deps is used to track all the memory dependencies of the graph.
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
+        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
             # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
+                # liveness analysis is only used in backward
+                deps[n] = len(n.users)
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
+                for input_n in n.all_input_nodes:
+                    if input_n in deps:
+                        deps[input_n] -= 1
             else:
+                # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
                 graph_info.bwd_mem_out += activation_size(n.meta['out'])
-        for input_n in n.all_input_nodes:
-            if input_n in deps:
-                deps[input_n] -= 1
-                if deps[input_n] <= 0:
-                    deps[input_n] = float('-inf')
     return graph_info
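
The analysis now does liveness tracking only in backward: every backward node with users registers `deps[n] = len(n.users)`, `_peak_memory` is sampled right after, and counts drop as each node consumes its inputs; a backward node without users is a `grad_out` node and is accumulated into `bwd_mem_out` instead. A minimal simulation of that bookkeeping on an invented three-node backward graph:

bwd_nodes = [                  # (name, out_bytes, inputs) -- all invented
    ('grad_y', 4, []),
    ('grad_w', 8, ['grad_y']),
    ('grad_x', 2, ['grad_y']),
]
out_size = {name: size for name, size, _ in bwd_nodes}
fanout = {'grad_y': 2, 'grad_w': 0, 'grad_x': 0}    # number of users per node

deps, bwd_mem_tmp, bwd_mem_out = {}, 0, 0
for name, _, inputs in bwd_nodes:
    if fanout[name]:
        deps[name] = fanout[name]    # liveness analysis is only used in backward
        bwd_mem_tmp = max(bwd_mem_tmp, sum(out_size[k] for k, v in deps.items() if v > 0))
        for inp in inputs:
            if inp in deps:
                deps[inp] -= 1
    else:
        bwd_mem_out += out_size[name]    # a userless backward node is a `grad_out`

print(bwd_mem_tmp, bwd_mem_out)    # 4 10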

@@ -1,9 +1,10 @@
 import torch
+from torch.fx import Node
 from typing import Union, Dict, List, Tuple
 from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY

-__all__ = ['activation_size', 'parameter_size']
+__all__ = ['activation_size', 'parameter_size', 'is_inplace']

 if META_COMPATIBILITY:
     aten = torch.ops.aten
@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
         aten.bernoulli_.float,

         # inplace reshaping
+        aten.copy_.default,
         aten.detach.default,
         aten.t.default,
         aten.transpose.int,
@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
         aten._unsafe_view.default,
     ]

-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
+    NORMALIZATION_ATEN = [
+        aten.native_batch_norm.default,
+        aten.native_layer_norm.default,
+        # aten.max_pool2d_with_indices.default,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']

 else:
     # TODO fill out the inplace ops
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
     for param in mod.parameters():
         param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
     return param_size
+
+
+def is_inplace(n: Node):
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        node (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+    inplace = False
+    if n.op == "call_function":
+        inplace = n.kwargs.get("inplace", False)
+        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
+            inplace = True
+    elif n.op == "call_module":
+        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+    return inplace
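
A quick check of what the new helper reports on a traced model (`symbolic_trace` and the `inplace` attribute on `nn.ReLU` are standard torch.fx / torch.nn behavior; the one-layer model is just for demonstration):

import torch
import torch.nn as nn
from colossalai.fx.profiler import is_inplace

gm = torch.fx.symbolic_trace(nn.Sequential(nn.ReLU(inplace=True)))
for node in gm.graph.nodes:
    if node.op == 'call_module':
        print(node.target, is_inplace(node))    # prints: 0 True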

@@ -222,6 +222,7 @@ flop_mapping = {
     aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
 }

 elementwise_flop_aten = [
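
The new entry charges `embedding_dense_backward` like an elementwise op: `elementwise_flop_counter(0, 1)` means zero flops per input element and one flop per output element. A sketch of what such a factory plausibly looks like (an assumption in the style of fvcore-like op handles, not ColossalAI's actual implementation):

from functools import reduce
from operator import mul

def elementwise_flop_counter(input_scale=1, output_scale=0):
    # flops proportional to element counts of inputs/outputs (assumed shapes)
    def counter(input_shapes, output_shapes):
        numel = lambda shape: reduce(mul, shape, 1)
        return (input_scale * sum(numel(s) for s in input_shapes)
                + output_scale * sum(numel(s) for s in output_shapes))
    return counter

count = elementwise_flop_counter(0, 1)
print(count([(32, 128)], [(30522, 128)]))    # one flop per output element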

@@ -1,12 +1,10 @@
-from dataclasses import dataclass
-from enum import auto
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
 from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS, activation_size
+from .memory import WEIRD_OPS
 from .tensor import MetaTensor
 from .opcount import flop_mapping
@@ -23,7 +21,7 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
     """
     Profile a Callable function with args and kwargs.
@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
         Phase.FORWARD: 0,
-        Phase.LOSS: 0,
         Phase.BACKWARD: 0,
     }
@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
         kwargs_node = tree_map(get_node, kwargs)
         node = subgraph.create_node('call_function', func, args_node, kwargs_node)

+        # do not allocate on `cpu`
+        if 'device' in kwargs:
+            kwargs['device'] = 'meta'
+
         def unwrap(x):
             # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
             if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     if target not in WEIRD_OPS:

         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
     else:

         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x

     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)
@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     tree_map(set_placeholder, kwargs)

     def pack(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
             x._node.meta['saved'] = True
         return x
@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     else:
         out = target(*args, **kwargs)

     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
     if is_autogradable(out) and out.requires_grad:
-        phase = Phase.LOSS
-        loss = out.sum()
-        phase = Phase.BACKWARD
-        loss.backward()
+        phase = Phase.BACKWARD
+        if isinstance(out, FlopTensor):
+            out._node.meta['save'] = False
+        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
+            out, device='meta')
+        torch.autograd.backward(out, FlopTensor(grad))

     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
-        >>> output, meta_info = profile_function(func)(input, inplace=False)
+        >>> output, meta_info = profile_function(func)(input)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
             args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
             kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta
@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, inplace=False, **kwargs)
+        out, meta = _profile(target, *args, **kwargs)
         return out, meta

     return f
@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
             args = tree_map(lambda x: x.to('meta'), args)
             kwargs = tree_map(lambda x: x.to('meta'), kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
-        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        out, meta = _profile(func, *args, **kwargs)
         return out, meta

     f.__name__ = module.__class__.__name__
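
Backward is now driven by `torch.autograd.backward(out, FlopTensor(grad))` with a meta-device gradient instead of `out.sum().backward()` (so the separate LOSS phase disappears), and the `pack` hook skips `nn.Parameter`s so weights are not booked as saved activations. A standalone sketch of that pack-hook idea on real tensors, using PyTorch's public `saved_tensors_hooks` API (the byte counter is ours, standing in for `FlopTensor`'s `_node.meta['saved']` marking):

import torch

saved_bytes = 0

def pack(x):
    global saved_bytes
    if not isinstance(x, torch.nn.Parameter):    # parameters are not activations
        saved_bytes += x.numel() * x.element_size()
    return x

def unpack(x):
    return x

w = torch.nn.Parameter(torch.randn(16, 16))
a = torch.randn(8, 16, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = torch.relu(a @ w)

# drive backward with an explicit gradient, as the patched profiler does
torch.autograd.backward(out, torch.ones_like(out))
print(saved_bytes)    # counts `a` and the relu output, but not the weight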
