Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

409 lines
15 KiB

import time
from functools import partial
from typing import Any, Callable, Dict, Tuple
import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
# Public entry points of this profiler module.
__all__ = [
    "profile_function",
    "profile_module",
    "profile_method",
]

# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes.
# `pack` inside `_profile_meta` adds `data_ptr()` values to this set and
# skips tensors whose pointer is already present.
cache: set = set()

# A global flag for inplace ops (and ops whose outputs are saved): while it is
# True, `pack` still records saved tensors but does not add them to `cache`.
do_not_cache: bool = False
def normalize_tuple(x):
    """Return ``x`` itself when it is already a tuple; otherwise wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
def is_autogradable(x):
    """Return True only when ``x`` is a ``torch.Tensor`` with a floating-point dtype."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.is_floating_point()
def detach_variables(x):
    """Cut a tensor out of its autograd graph while keeping its ``requires_grad`` flag.

    Non-tensor values are returned untouched. For tensors, the result is a
    fresh leaf (``detach()``) whose ``requires_grad`` matches the input.
    """
    if not isinstance(x, torch.Tensor):
        return x
    needs_grad = x.requires_grad
    detached = x.detach()
    detached.requires_grad = needs_grad
    return detached
@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30

    To profile the actual forward memory, we first run ``target`` under ``torch.no_grad()`` to get
    ``fwd_mem_out``. We then run it again with grad enabled; the extra allocated memory minus
    ``fwd_mem_out`` is recorded as ``fwd_mem_tmp``.

    To profile the actual backward memory, we build dummy (all-ones) gradients for
    ``torch.autograd.backward``. ``bwd_mem_tmp`` is the memory peak during backward minus
    ``bwd_mem_out``. ``bwd_mem_out`` is the size of args and kwargs, plus the parameter size of
    the owning object when that object has ``parameters``.

    Forward and backward wall-clock times are also recorded. CUDA is synchronized around each
    timed region so that asynchronous kernel launches do not hide execution time.

    Args:
        target (Callable | str): A Callable function, or the name of a method to invoke on
            ``args[0]``.
        args (Any): Positional arguments.
        kwargs (Any): Keyword arguments.

    Returns:
        Tuple[Tuple[Any, ...], GraphInfo]: Output for the next node (detached from the graph),
        and a ``GraphInfo`` holding memory costs and the real forward and backward times.
    """

    def _synchronize():
        # CUDA kernels run asynchronously; without a barrier the timestamps would
        # only capture kernel-launch overhead rather than execution time.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    graphinfo = GraphInfo()

    # detach input from the graph
    args = tree_map(detach_variables, args)
    kwargs = tree_map(detach_variables, kwargs)

    # Resolve what to call and which object may own parameters.
    if isinstance(target, str):
        # args[0] is the `self` object for this method call
        self_obj, *call_args = args
        fn = getattr(self_obj, target)
        owner = self_obj
    else:
        fn = target
        call_args = args
        # Bound methods (e.g. `module.forward`) expose their owner via `__self__`.
        # Plain Python functions have no such attribute, so fall back to None.
        owner = getattr(target, "__self__", None)

    # calculate fwd_mem_out
    mem_stamp0 = torch.cuda.memory_allocated()
    with torch.no_grad():
        out = fn(*call_args, **kwargs)
    mem_stamp1 = torch.cuda.memory_allocated()
    graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
    del out

    # calculate fwd_mem_tmp & fwd_time
    mem_stamp0 = torch.cuda.memory_allocated()
    _synchronize()
    fwd_time0 = time.perf_counter()
    out = fn(*call_args, **kwargs)
    _synchronize()
    fwd_time1 = time.perf_counter()
    graphinfo.fwd_time = fwd_time1 - fwd_time0
    mem_stamp1 = torch.cuda.memory_allocated()
    graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

    # calculate bwd_mem_tmp & bwd_time
    grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    _synchronize()
    bwd_time0 = time.perf_counter()
    torch.autograd.backward(out, grad_tensors=grad_tensors)
    _synchronize()
    bwd_time1 = time.perf_counter()
    graphinfo.bwd_time = bwd_time1 - bwd_time0
    mem_stamp1 = torch.cuda.max_memory_allocated()

    # calculate bwd memory stats
    # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
    graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
    graphinfo.bwd_mem_out += parameter_size(owner) if hasattr(owner, "parameters") else 0
    graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    return tree_map(detach_variables, out), graphinfo
@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """
    Profile a Callable function with args and kwargs on meta devices.

    Each aten op dispatched while running ``target`` is recorded as a node in a
    private ``torch.fx.Graph``. FLOPs are accumulated per phase (forward / backward),
    and saved tensors are attached to nodes through ``saved_tensors_hooks``.
    ``autograd_graph_analysis`` then turns this graph into a ``GraphInfo``.
    Backward is only run when every output is a floating-point tensor that
    requires grad.

    Args:
        target (Callable | str): A Callable function, or the name of a method to invoke
            on ``args[0]``.
        args (Any): Positional arguments; tensors are wrapped into ``FlopTensor``.
        kwargs (Any): Keyword arguments; tensors are wrapped into ``FlopTensor``.

    Returns:
        out (Tuple[Any, ...]): The output of ``target``, with every ``torch.Tensor``
            re-wrapped as a plain ``MetaTensor``.
        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`;
            ``fwd_flop``, ``bwd_flop`` and ``fwd_out`` are filled in here.
    """
    # This subgraph traces aten level ops inside one node.
    subgraph = Graph()

    # `flop_count` is a dictionary shared (by closure) with `FlopTensor.__torch_dispatch__`;
    # it accumulates FLOPs for each autograd phase.
    flop_count = {
        Phase.FORWARD: 0,
        Phase.BACKWARD: 0,
    }

    # FlopTensor not only get the flop statistics of a single node,
    # it also build a full autograd graph for this node.
    # This makes sure we can analyze the dependencies of memory, and
    # decide which forward intermediate results should be kept until
    # backward is executed.
    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):
        # The subgraph node that produced this tensor (placeholder or call_function).
        _node: Node = None

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # Map each FlopTensor argument to its producing node (others become None)
            # so that the new node's inputs mirror the op's tensor dependencies.
            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
            node = subgraph.create_node("call_function", func, args_node, kwargs_node)

            out = super().__torch_dispatch__(func, types, args, kwargs)

            # `phase` is read from the enclosing `_profile_meta` scope at call time, so it
            # reflects whether we are currently in the forward run or inside backward.
            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
            node.meta["phase"] = phase

            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
            # `Phase.FORWARD`
            if phase == Phase.FORWARD:
                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                    node.meta["phase"] = Phase.PLACEHOLDER

            # TODO(yby): specify `saved_tensors` for backward memory estimation
            node.meta["saved_tensor"] = []
            if phase == Phase.BACKWARD:
                node.meta["saved_tensor"] = normalize_tuple(out)

            def wrap(x):
                # Re-wrap outputs so later ops keep being traced and know their producer.
                if isinstance(x, MetaTensor):
                    x = FlopTensor(x)
                    x._node = node
                return x

            out = tree_map(wrap, out)
            return out

    def wrap(x):
        # Turn each input tensor into a FlopTensor backed by a fresh placeholder node;
        # floating-point inputs are marked requires_grad so backward can be traced.
        if isinstance(x, torch.Tensor):
            x = FlopTensor(x)
            if is_autogradable(x):
                x.requires_grad_(True)
            x._node = subgraph.create_node(
                "placeholder",
                "placeholder",
                (subgraph._root,),
                name=subgraph._graph_namespace.create_name("input", x._tensor),
            )
            x._node.meta["phase"] = Phase.PLACEHOLDER
            x._node.meta["saved_tensor"] = []
        return x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    def pack(x):
        # saved_tensors_hooks pack hook: record each tensor autograd saves on its
        # producing node, skipping storages already seen (by data_ptr) in the
        # module-level `cache`. The pointer is only added to `cache` when
        # `do_not_cache` is False.
        global cache, do_not_cache
        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
            tensor = x._tensor.detach()
            # Keep the original data_ptr so later analysis identifies the same storage.
            tensor.data_ptr = x._tensor.data_ptr
            x._node.meta["saved_tensor"] += [tensor]
            if not do_not_cache:
                cache.add(x._tensor.data_ptr())
        return x

    def unpack(x):
        # Identity: `pack` returns the tensor unchanged, so nothing to restore.
        return x

    # `phase` will mark the phase of autograd from outside scope.
    phase = Phase.FORWARD

    # mark saved tensors with saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        if isinstance(target, str):
            # args[0] is the `self` object for this method call
            self_obj, *args_tail = args
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        else:
            out = target(*args, **kwargs)

        # If the output is not a floating point `torch.Tensor` or it does not
        # requires grad, then we should not run backward for this node.
        if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
            grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
            # Switch phase before backward so dispatched backward ops are attributed to it.
            phase = Phase.BACKWARD
            torch.autograd.backward(
                out,
                grad_out,
            )

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

    def extract_tensor(x: Any):
        # Unwrap MetaTensors to their detached inner tensor (keeping data_ptr);
        # `torch.finfo` values fall through and yield None, all else is kept as-is.
        if isinstance(x, MetaTensor):
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            return tensor
        if not isinstance(x, torch.finfo):
            return x

    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))

    def unwrap(x):
        # Hand plain MetaTensors (not FlopTensors) back to the caller.
        return MetaTensor(x) if isinstance(x, torch.Tensor) else x

    return tree_map(unwrap, out), graph_info
@compatibility(is_backward_compatible=True)
def profile_function(target: "Target", device: str = "meta") -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.

    Args:
        target (Target): The function to profile.
        device (str): ``"meta"`` (default) profiles with meta tensors; any other
            value profiles on concrete devices.

    Returns:
        Callable: A wrapper with the same ``__name__`` as ``target`` that returns
        ``(output, meta_info)``.

    Warnings:
        With ``device="meta"`` you may only use tensors with `device=meta` for this
        wrapped function. Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, meta_info = profile_function(func)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # find the grad for parameter in args and kwargs
        param_size = 0

        def get_param_size(x):
            nonlocal param_size
            if isinstance(x, Parameter):
                param_size += activation_size(x)

        tree_map(get_param_size, args)
        tree_map(get_param_size, kwargs)

        # If there is an argument that this `call_function` is inplace, we should
        # still run the profiling but discard some results regarding `target`
        global do_not_cache
        inplace = kwargs.get("inplace", False)
        if target in OUTPUT_SAVED_OPS or inplace:
            do_not_cache = True
        if inplace:
            # Profile the out-of-place variant so autograd can be traced.
            kwargs["inplace"] = False
        try:
            if device == "meta":
                out, meta = _profile_meta(target, *args, **kwargs)
            else:
                out, meta = _profile_concrete(target, *args, **kwargs)
        finally:
            # Always restore state, otherwise a failing profile would leave the
            # global flag set and silently disable caching for later nodes.
            if inplace:
                kwargs["inplace"] = True
            do_not_cache = False

        if inplace:
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = target.__name__
    return f
@compatibility(is_backward_compatible=True)
def profile_method(target: "Target", device: str = "meta") -> Callable:
    """
    Wrap a `call_method` node in order to record the memory cost and FLOPs of
    the execution.

    Args:
        target (Target): Name of the method to call on the wrapper's first
            positional argument; must be a ``str`` when the wrapper is invoked.
        device (str): ``"meta"`` (default) profiles with meta tensors; any other
            value profiles on concrete devices.

    Returns:
        Callable: A wrapper returning ``(output, meta_info)``.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f"{target} instance is not str."
        profiler = _profile_meta if device == "meta" else _profile_concrete
        out, meta = profiler(target, *args, **kwargs)
        return out, meta

    return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Args:
        module (torch.nn.Module): The module whose ``forward`` is profiled.
        device (str): ``"meta"`` (default) profiles with meta tensors; any other
            value profiles on concrete devices.

    Returns:
        Callable: A wrapper named after the module's class that returns
        ``(output, meta_info)``.

    Warnings:
        With ``device="meta"`` you may only use tensors with `device=meta` for this
        wrapped function. Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, meta_info = profile_module(mod)(input)
    """
    # `forward` is bound once, when the wrapper is created.
    func = module.forward

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # calculate parameter size
        param_size = parameter_size(module)

        # If there is an argument that this `call_module` is inplace, we should
        # still run the profiling but discard some results regarding `module`.
        global do_not_cache
        inplace = getattr(module, "inplace", False)
        if type(module) in OUTPUT_SAVED_MOD or inplace:
            do_not_cache = True
        if inplace:
            # Profile the out-of-place variant so autograd can be traced.
            module.inplace = False
        try:
            if device == "meta":
                out, meta = _profile_meta(func, *args, **kwargs)
            else:
                out, meta = _profile_concrete(func, *args, **kwargs)
        finally:
            # Restore the user's module and the global flag even if profiling
            # raised; otherwise the module would stay permanently non-inplace.
            if inplace:
                module.inplace = True
            do_not_cache = False

        if inplace:
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        # grad for param will not be counted
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = module.__class__.__name__
    return f