mirror of https://github.com/hpcaitech/ColossalAI

Commit: [fx/tuning] tune performance on rotor with meta info. (#1599)
parent a7cda6f57d, commit cd5cf2bcc9
The commit touches seven files. First, the rotor checkpoint solver (the module defining solver_rotor, _construct_chain and the _get_bwd_mem_tmp helpers):

@@ -1,8 +1,7 @@
 from typing import List, Tuple
-import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size
 import math
 from .linearize import linearize
 from .utils import *

@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
     what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
-    ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
+    # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

     # Initialize borders of the tables for lmax-lmin = 0
     for m in range(mmax + 1):

@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]


-def _compute_size(obj: torch.Tensor) -> int:
-    return obj.numel() * obj.element_size()
-
-
-def _compute_output_size(node: List[Node]) -> int:
-    """Compute the output size of a node
-
-    Args:
-        node (List[Node]): node, list of torch.fx.Node
-
-    Returns:
-        int: output size
-    """
-
-    return node[-1].meta['tensor_meta'].numel * torch.tensor([],
-                                                             dtype=node[-1].meta['tensor_meta'].dtype).element_size()
-
-
-def _get_inplace(node: Node) -> bool:
-    """Get the inplace argument from torch.fx.Node
-
-    Args:
-        node (Node): torch.fx.Node
-
-    Returns:
-        bool: indicates whether this op is inplace
-    """
-
-    is_inplace = False
-    if node.op == "call_function":
-        is_inplace = node.kwargs.get("inplace", False)
-    elif node.op == "call_module":
-        is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
-
-    return is_inplace
-
-
 def _fwd_xbar(node: List[Node]) -> int:
     """Get the forward xbar of a node

@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         for k, v in deps.items():
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
+            if v == float('-inf'):
+                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']

         return deps_size

     bwd_mem_tmp = 0
     deps = {}

-    # add all the users for last node into deps,
-    # as those nodes' gradient out will be stored in memory
-    for child in node[-1].users:
-        deps[child] = 1
     for n in reversed(node):
+        deps[n] = len(n.all_input_nodes)
         bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
-
-        deps[n] = len(n.all_input_nodes)
         for child in n.users:
             if child in deps:
                 deps[child] -= 1
-        for key in list(deps.keys()):
-            if deps[key] == 0:
-                del deps[key]
+                if deps[child] <= 0:
+                    deps[child] = float('-inf')    # free

     return bwd_mem_tmp


-def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:

     fwd_time = []
     bwd_time = []
-    if isinstance(data, torch.Tensor):
-        xbar_sizes = [_compute_size(data)]
-        x_sizes = [_compute_size(data)]
-    elif isinstance(data, list) or isinstance(data, tuple):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data])]
-        x_sizes = [sum([_compute_size(obj) for obj in data])]
-    elif isinstance(data, dict):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
-        x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
+    xbar_sizes = [activation_size(input)]
+    x_sizes = [activation_size(input)]

     # currently we can't get the temp memory needed in fwd
     tmp_fwd = [0] * len(node_list)
     tmp_bwd = []

@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(_compute_output_size(node))
+        x_sizes.append(node[-1].meta['fwd_mem_out'])
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
         tmp_bwd.append(_get_bwd_mem_tmp(node))

-        # if a node with only one inplace op, we need to let x_bar = 0
-        if len(node) == 1 and _get_inplace(node[0]):
-            xbar_sizes[-1] = 0
-
     bwd_time.append(0)

     # currently we view loss backward temp as zero

@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
                  mem_limit: int,
                  mem_slots: int = 500,
                  cnode: List[str] = None,
-                 eps: float = 0.02) -> ColoGraphModule:
+                 eps: float = 0.0) -> ColoGraphModule:
    """solver that automatically find activation checkpoint in rotor's manner

    Args:

@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
        mem_limit (int): memory budget in Byte.
        mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
        cnode (List[Node], optional): common node list for linearize. Defaults to None.
-        eps (float): epsilon for memory decay. Defaults to 0.02
+        eps (float): epsilon for memory decay. Defaults to 0.0

    Returns:
        ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
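
Note on the _get_bwd_mem_tmp rewrite above: instead of deleting a dependency key once its count reaches zero, the new code tombstones it with float('-inf'), and _get_deps_size subtracts the freed fwd_mem_tmp/fwd_mem_out for tombstoned keys. A minimal self-contained sketch of that pattern (hypothetical stage records, not the solver's real node metadata):

    from typing import Dict, List

    def peak_live_bytes(stages: List[dict]) -> int:
        # each stage: {'name': str, 'bytes': int, 'n_deps': int, 'uses': [str]}
        deps: Dict[str, float] = {}
        size = {s['name']: s['bytes'] for s in stages}

        def live() -> int:
            # only positive counts are live; float('-inf') marks freed buffers
            return sum(size[k] for k, v in deps.items() if v > 0)

        peak = 0
        for s in reversed(stages):
            deps[s['name']] = s['n_deps']
            peak = max(peak, live())
            for used in s['uses']:
                if used in deps:
                    deps[used] -= 1
                    if deps[used] <= 0:
                        deps[used] = float('-inf')    # freed, key kept for accounting
        return peak

Keeping the key around (rather than del, as the old code did) lets the size query distinguish "never allocated" from "allocated and already freed".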
Next, the linearize pass (imported by the solver as `from .linearize import linearize`):

@@ -1,5 +1,6 @@
 from typing import List, Any
 from torch.fx import GraphModule, Node
+from colossalai.fx.profiler import is_inplace

 # Common nodes are type of nodes that could be seen as attributes and remain
 # unchanged throughout the whole model, it will be used several times by

@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     Returns:
         List[List[Node]]: List of list, each inside list of Node presents
         the actual 'node' in linearized manner.
+
+    Remarks:
+        We merge the inplace ops into the previous node.
     """

     def _is_sink() -> bool:

@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
             bool
         """

-        return not sum([v for _, v in deps.items()])
+        return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))

     # make sure that item in cnode is valid
     if cnode:
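
The new `and not any(map(is_inplace, n.users))` clause in _is_sink is what the added Remarks means: a node whose pending user mutates it in place can never end a stage, so inplace ops are merged into their producer's stage. A hedged illustration on a toy traced module (is_inplace restated inline; the real helper now lives in colossalai.fx.profiler):

    import torch.nn as nn
    from torch.fx import symbolic_trace

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True), nn.Linear(4, 4))
    gm = symbolic_trace(model)

    def is_inplace(n):
        # same decision rule this commit centralizes in the profiler
        if n.op == "call_function":
            return n.kwargs.get("inplace", False)
        if n.op == "call_module":
            return getattr(gm.get_submodule(n.target), "inplace", False)
        return False

    for node in gm.graph.nodes:
        blocked = any(is_inplace(u) for u in node.users)
        print(f"{node.name}: may close a stage -> {not blocked}")

Here the first Linear's only user is the inplace ReLU, so the linearizer keeps the two in one stage instead of checkpointing an activation the ReLU would overwrite.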
The profiler package __init__ re-exports the new is_inplace helper:

@@ -7,4 +7,4 @@ else:
     from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module

 from .dataflow import GraphInfo
-from .memory import parameter_size, activation_size
+from .memory import parameter_size, activation_size, is_inplace
The dataflow analysis (Phase / autograd_graph_analysis): the LOSS phase is dropped, and normalization / clone aten ops are special-cased in the liveness accounting:

@@ -1,16 +1,17 @@
 from dataclasses import dataclass
 from enum import Enum
-from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
-from .memory import activation_size
+from .memory import activation_size, is_inplace
+from . import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from .memory import NORMALIZATION_ATEN, CLONE_ATEN


 class Phase(Enum):
     FORWARD = 0
-    LOSS = 1
-    BACKWARD = 2
-    PLACEHOLDER = 3
+    BACKWARD = 1
+    PLACEHOLDER = 2


 @dataclass

@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0:
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
                 peak_mem += activation_size(k.meta['out'])
+            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
+                peak_mem -= activation_size(k.meta['out'])
         return peak_mem

     # deps is used to track all the memory dependencies of the graph.

@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:

     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
+        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
             # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.

@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
-                # liveness analysis is only used in backward
-                deps[n] = len(n.users)
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
-                for input_n in n.all_input_nodes:
-                    if input_n in deps:
-                        deps[input_n] -= 1
             else:
+                # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
                 graph_info.bwd_mem_out += activation_size(n.meta['out'])
+            for input_n in n.all_input_nodes:
+                if input_n in deps:
+                    deps[input_n] -= 1
+                    if deps[input_n] <= 0:
+                        deps[input_n] = float('-inf')
     return graph_info
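
A reading note on the rewritten save condition: `and` binds tighter than `or`, so the line groups as (is_saved(n) and n.target not in NORMALIZATION_ATEN) or (some user is a clone). Normalization outputs stop being billed as saved activations, while anything a user clones is counted even without the `saved` mark. Restated as a standalone predicate (hypothetical helper, for clarity only):

    def counts_toward_fwd_mem(n, saved: bool, normalization_aten, clone_aten) -> bool:
        # grouping matches the diff line above: (saved and not a norm output)
        # OR (some user of n is a clone op)
        return (saved and n.target not in normalization_aten) \
            or any(u.target in clone_aten for u in n.users)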
The memory helpers (activation_size, parameter_size, and the new shared is_inplace): aten.copy_ is registered as inplace, and NORMALIZATION_ATEN / CLONE_ATEN lists are introduced for the dataflow analysis above:

@@ -1,9 +1,10 @@
 import torch
+from torch.fx import Node
 from typing import Union, Dict, List, Tuple
 from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY

-__all__ = ['activation_size', 'parameter_size']
+__all__ = ['activation_size', 'parameter_size', 'is_inplace']

 if META_COMPATIBILITY:
     aten = torch.ops.aten

@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
         aten.bernoulli_.float,

         # inplace reshaping
+        aten.copy_.default,
         aten.detach.default,
         aten.t.default,
         aten.transpose.int,

@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
         aten._unsafe_view.default,
     ]

-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
+    NORMALIZATION_ATEN = [
+        aten.native_batch_norm.default,
+        aten.native_layer_norm.default,
+        # aten.max_pool2d_with_indices.default,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']

 else:
     # TODO fill out the inplace ops

@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
     for param in mod.parameters():
         param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
     return param_size
+
+
+def is_inplace(n: Node):
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        node (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+    inplace = False
+    if n.op == "call_function":
+        inplace = n.kwargs.get("inplace", False)
+        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
+            inplace = True
+    elif n.op == "call_module":
+        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+
+    return inplace
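
activation_size now replaces the solver-local _compute_size/_compute_output_size helpers deleted earlier. A hedged sketch of the byte accounting it performs, using the same numel * element_size arithmetic visible in parameter_size, with the recursion over containers inferred from the tensor/list/tuple/dict cases the old solver handled:

    import torch

    def activation_size_sketch(obj) -> int:
        # bytes occupied by every tensor reachable inside obj
        if isinstance(obj, torch.Tensor):
            return obj.numel() * obj.element_size()
        if isinstance(obj, dict):
            return sum(activation_size_sketch(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            return sum(activation_size_sketch(o) for o in obj)
        return 0

    out = {'logits': torch.zeros(8, 10), 'loss': torch.zeros(())}
    print(activation_size_sketch(out))    # 8*10*4 + 1*4 = 324 bytes for fp32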
The flop-count table (flop_mapping) gains an entry for the embedding backward kernel:

@@ -222,6 +222,7 @@ flop_mapping = {
     aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
 }

 elementwise_flop_aten = [
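
The new entry bills embedding_dense_backward like the neighboring elementwise ops: elementwise_flop_counter(0, 1) reads as zero flops per input element and one per output element. A hedged sketch of that counting convention (the real factory's signature and internals may differ):

    from functools import reduce
    from operator import mul

    def elementwise_flops(input_shapes, output_shapes, input_scale=0, output_scale=0):
        # flops proportional to total element counts on each side
        numel = lambda shapes: sum(reduce(mul, s, 1) for s in shapes)
        return input_scale * numel(input_shapes) + output_scale * numel(output_shapes)

    # one flop per element of the gradient table the op produces
    print(elementwise_flops([(32, 128)], [(30522, 128)], input_scale=0, output_scale=1))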
Finally, the profiler itself (_profile and the profile_function / profile_method / profile_module wrappers): the inplace flag and the synthetic LOSS phase are removed, and backward is driven by an explicit meta gradient:

@@ -1,12 +1,10 @@
-from dataclasses import dataclass
-from enum import auto
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
 from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS, activation_size
+from .memory import WEIRD_OPS
 from .tensor import MetaTensor
 from .opcount import flop_mapping

@@ -23,7 +21,7 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
     """
     Profile a Callable function with args and kwargs.

@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
         Phase.FORWARD: 0,
-        Phase.LOSS: 0,
         Phase.BACKWARD: 0,
     }

@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
             kwargs_node = tree_map(get_node, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+
+            # do not allocate on `cpu`
+            if 'device' in kwargs:
+                kwargs['device'] = 'meta'

             def unwrap(x):
                 # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
                 if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):

@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     if target not in WEIRD_OPS:

         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
     else:

         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x

     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)

@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     tree_map(set_placeholder, kwargs)

     def pack(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
             x._node.meta['saved'] = True
         return x

@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     else:
         out = target(*args, **kwargs)

     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
     if is_autogradable(out) and out.requires_grad:
-        phase = Phase.LOSS
-        loss = out.sum()
-        phase = Phase.BACKWARD
-        loss.backward()
+        phase = Phase.BACKWARD
+        if isinstance(out, FlopTensor):
+            out._node.meta['save'] = False
+        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
+            out, device='meta')
+        torch.autograd.backward(out, FlopTensor(grad))

     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
-        >>> output, meta_info = profile_function(func)(input, inplace=False)
+        >>> output, meta_info = profile_function(func)(input)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
             args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
             kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta

@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, inplace=False, **kwargs)
+        out, meta = _profile(target, *args, **kwargs)
         return out, meta

     return f

@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
             args = tree_map(lambda x: x.to('meta'), args)
             kwargs = tree_map(lambda x: x.to('meta'), kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
-        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        out, meta = _profile(func, *args, **kwargs)
         return out, meta

     f.__name__ = module.__class__.__name__
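
The backward driver changed shape: the old code synthesized a scalar loss (out.sum()) under a dedicated LOSS phase, while the new code feeds a gradient of out's own shape straight into autograd, which is why Phase.LOSS disappears everywhere. The equivalence on ordinary tensors (plain CPU example, no FlopTensor or meta device involved):

    import torch

    x = torch.randn(4, 4, requires_grad=True)
    out = x * 2

    grad = torch.ones_like(out)         # stands in for the meta-device grad above
    torch.autograd.backward(out, grad)  # same effect as out.backward(grad)
    print(x.grad)                       # all twos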