[fx/tuning] tune performance on rotor with meta info. (#1599)

pull/1601/head^2
Super Daniel 2022-09-15 14:46:36 +08:00 committed by GitHub
parent a7cda6f57d
commit cd5cf2bcc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 107 deletions

View File

@ -1,8 +1,7 @@
from typing import List, Tuple
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
from colossalai.fx.profiler import activation_size, parameter_size
import math
from .linearize import linearize
from .utils import *
@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
# Build table
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mmax + 1):
@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]
def _compute_size(obj: torch.Tensor) -> int:
return obj.numel() * obj.element_size()
def _compute_output_size(node: List[Node]) -> int:
"""Compute the output size of a node
Args:
node (List[Node]): node, list of torch.fx.Node
Returns:
int: output size
"""
return node[-1].meta['tensor_meta'].numel * torch.tensor([],
dtype=node[-1].meta['tensor_meta'].dtype).element_size()
def _get_inplace(node: Node) -> bool:
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
is_inplace = False
if node.op == "call_function":
is_inplace = node.kwargs.get("inplace", False)
elif node.op == "call_module":
is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
return is_inplace
def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
for k, v in deps.items():
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
return deps_size
bwd_mem_tmp = 0
deps = {}
# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for child in node[-1].users:
deps[child] = 1
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
deps[n] = len(n.all_input_nodes)
for child in n.users:
if child in deps:
deps[child] -= 1
for key in list(deps.keys()):
if deps[key] == 0:
del deps[key]
if deps[child] <= 0:
deps[child] = float('-inf') # free
return bwd_mem_tmp
def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:
fwd_time = []
bwd_time = []
if isinstance(data, torch.Tensor):
xbar_sizes = [_compute_size(data)]
x_sizes = [_compute_size(data)]
elif isinstance(data, list) or isinstance(data, tuple):
xbar_sizes = [sum([_compute_size(obj) for obj in data])]
x_sizes = [sum([_compute_size(obj) for obj in data])]
elif isinstance(data, dict):
xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
xbar_sizes = [activation_size(input)]
x_sizes = [activation_size(input)]
# currently we can't get the temp memory needed in fwd
tmp_fwd = [0] * len(node_list)
tmp_bwd = []
@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
for idx, node in enumerate(node_list):
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(_compute_output_size(node))
x_sizes.append(node[-1].meta['fwd_mem_out'])
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_bwd.append(_get_bwd_mem_tmp(node))
# if a node with only one inplace op, we need to let x_bar = 0
if len(node) == 1 and _get_inplace(node[0]):
xbar_sizes[-1] = 0
bwd_time.append(0)
# currently we view loss backward temp as zero
@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit: int,
mem_slots: int = 500,
cnode: List[str] = None,
eps: float = 0.02) -> ColoGraphModule:
eps: float = 0.0) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.02
eps (float): epsilon for memory decay. Defaults to 0.0
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute

View File

@ -1,5 +1,6 @@
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
def _is_sink() -> bool:
@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
bool
"""
return not sum([v for _, v in deps.items()])
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
# make sure that item in cnode is valid
if cnode:

View File

@ -7,4 +7,4 @@ else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
from .dataflow import GraphInfo
from .memory import parameter_size, activation_size
from .memory import parameter_size, activation_size, is_inplace

View File

@ -1,16 +1,17 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size
from .memory import activation_size, is_inplace
from . import META_COMPATIBILITY
if META_COMPATIBILITY:
from .memory import NORMALIZATION_ATEN, CLONE_ATEN
class Phase(Enum):
FORWARD = 0
LOSS = 1
BACKWARD = 2
PLACEHOLDER = 3
BACKWARD = 1
PLACEHOLDER = 2
@dataclass
@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0:
if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
peak_mem += activation_size(k.meta['out'])
if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
peak_mem -= activation_size(k.meta['out'])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
for n in graph.nodes:
n: Node
if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
# liveness analysis is only used in backward
deps[n] = len(n.users)
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['out'])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
deps[input_n] = float('-inf')
return graph_info

View File

@ -1,9 +1,10 @@
import torch
from torch.fx import Node
from typing import Union, Dict, List, Tuple
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY
__all__ = ['activation_size', 'parameter_size']
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
if META_COMPATIBILITY:
aten = torch.ops.aten
@ -21,6 +22,7 @@ if META_COMPATIBILITY:
aten.bernoulli_.float,
# inplace reshaping
aten.copy_.default,
aten.detach.default,
aten.t.default,
aten.transpose.int,
@ -28,7 +30,17 @@ if META_COMPATIBILITY:
aten._unsafe_view.default,
]
__all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
NORMALIZATION_ATEN = [
aten.native_batch_norm.default,
aten.native_layer_norm.default,
# aten.max_pool2d_with_indices.default,
]
CLONE_ATEN = [
aten.clone.default,
]
__all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
else:
# TODO fill out the inplace ops
@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if META_COMPATIBILITY and n.target in INPLACE_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace

View File

@ -222,6 +222,7 @@ flop_mapping = {
aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
}
elementwise_flop_aten = [

View File

@ -1,12 +1,10 @@
from dataclasses import dataclass
from enum import auto
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
from .memory import WEIRD_OPS, activation_size
from .memory import WEIRD_OPS
from .tensor import MetaTensor
from .opcount import flop_mapping
@ -23,7 +21,7 @@ def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
"""
Profile a Callable function with args and kwargs.
@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# `flop_count`` serves as a global dictionary to store results.
flop_count = {
Phase.FORWARD: 0,
Phase.LOSS: 0,
Phase.BACKWARD: 0,
}
@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
kwargs_node = tree_map(get_node, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
# do not allocate on `cpu`
if 'device' in kwargs:
kwargs['device'] = 'meta'
def unwrap(x):
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
if target not in WEIRD_OPS:
def wrap(x):
return FlopTensor(x.detach().requires_grad_(
True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
return FlopTensor(
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
else:
def wrap(x):
return FlopTensor(x.detach().requires_grad_(
False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
return FlopTensor(
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
# Basically, we need to detach the args and kwargs from the outer graph.
args = tree_map(wrap, args)
@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
tree_map(set_placeholder, kwargs)
def pack(x):
if isinstance(x, FlopTensor):
if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
x._node.meta['saved'] = True
return x
@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
else:
out = target(*args, **kwargs)
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if is_autogradable(out) and out.requires_grad:
phase = Phase.LOSS
loss = out.sum()
phase = Phase.BACKWARD
loss.backward()
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if is_autogradable(out) and out.requires_grad:
phase = Phase.BACKWARD
if isinstance(out, FlopTensor):
out._node.meta['save'] = False
grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
out, device='meta')
torch.autograd.backward(out, FlopTensor(grad))
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, meta_info = profile_function(func)(input, inplace=False)
>>> output, meta_info = profile_function(func)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
out, meta = _profile(func, *args, **kwargs)
return out, meta
@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
out, meta = _profile(target, *args, inplace=False, **kwargs)
out, meta = _profile(target, *args, **kwargs)
return out, meta
return f
@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
out, meta = _profile(func, *args, **kwargs)
return out, meta
f.__name__ = module.__class__.__name__