from functools import partial from typing import Callable, Any, Dict, Tuple import torch from torch.nn.parameter import Parameter from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo from .memory import activation_size, parameter_size from .constant import ALIAS_ATEN from .tensor import MetaTensor from .opcount import flop_mapping import time __all__ = ['profile_function', 'profile_module', 'profile_method'] # super-dainiu: this cache should be global, otherwise it cannot # track duplicated tensors between nodes cache = set() # a global identifier for inplace ops do_not_cache = False def normalize_tuple(x): if not isinstance(x, tuple): return (x,) return x def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() def detach_variables(x): if isinstance(x, torch.Tensor): requires_grad = x.requires_grad x = x.detach() x.requires_grad = requires_grad return x def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30 To profile the actual forward memory, we first run target in the context torch.no_grad() to get the fwd_mem_out, then we run target with grad enable to found the extra memory stored in the memory by memory allocated minus the fwd_mem_out. To profile the actual backward memory, we first make dummy gradient for torch.autograd.backward, then find the bwd_mem_tmp with memory peak during the process minus bwd_mem_out(it is actually equal to size of args and kwargs). We also add time stamps to profile the real forward and backward time. Args: target (Callable): A Callable function args (Any): Arguments kwargs (Any): Arguments Returns: Tuple[Tuple[Any, ...], GraphInfo]: Output for next node & memory cost and real forward and backward time. """ graphinfo = GraphInfo() # detach input from the graph args = tree_map(detach_variables, args) kwargs = tree_map(detach_variables, kwargs) if isinstance(target, str): # args[0] is the `self` object for this method call self_obj, *args_tail = args # calculate fwd_mem_out mem_stamp0 = torch.cuda.memory_allocated() with torch.no_grad(): out = getattr(self_obj, target)(*args_tail, **kwargs) mem_stamp1 = torch.cuda.memory_allocated() graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 del out # calculate fwd_mem_tmp & fwd_time mem_stamp0 = torch.cuda.memory_allocated() fwd_time0 = time.time() out = getattr(self_obj, target)(*args_tail, **kwargs) fwd_time1 = time.time() graphinfo.fwd_time = fwd_time1 - fwd_time0 mem_stamp1 = torch.cuda.memory_allocated() graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out # calculate bwd_mem_tmp & bwd_time grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) torch.cuda.reset_peak_memory_stats() mem_stamp0 = torch.cuda.memory_allocated() bwd_time0 = time.time() torch.autograd.backward(out, grad_tensors=grad_tensors) bwd_time1 = time.time() graphinfo.bwd_time = bwd_time1 - bwd_time0 mem_stamp1 = torch.cuda.max_memory_allocated() # calculate bwd memory stats # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out else: # calculate fwd_mem_out mem_stamp0 = torch.cuda.memory_allocated() with torch.no_grad(): out = target(*args, **kwargs) mem_stamp1 = torch.cuda.memory_allocated() graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0 del out # calculate fwd_mem_tmp & fwd_time mem_stamp0 = torch.cuda.memory_allocated() fwd_time0 = time.time() out = target(*args, **kwargs) fwd_time1 = time.time() graphinfo.fwd_time = fwd_time1 - fwd_time0 mem_stamp1 = torch.cuda.memory_allocated() graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out # calculate bwd_mem_tmp & bwd_time grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out) torch.cuda.reset_peak_memory_stats() mem_stamp0 = torch.cuda.memory_allocated() bwd_time0 = time.time() torch.autograd.backward(out, grad_tensors=grad_tensors) bwd_time1 = time.time() graphinfo.bwd_time = bwd_time1 - bwd_time0 mem_stamp1 = torch.cuda.max_memory_allocated() # calculate bwd memory stats # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs) graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0 graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out return tree_map(detach_variables, out), graphinfo def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: """ Profile a Callable function with args and kwargs on meta devices. Args: target (Callable): A Callable function args (Any): Argument kwargs (Any): Argument Returns: out (Tuple[Any, ...]): The argument value that was retrieved. meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # This subgraph traces aten level ops inside one node. subgraph = Graph() # `flop_count`` serves as a global dictionary to store results. flop_count = { Phase.FORWARD: 0, Phase.BACKWARD: 0, } # FlopTensor not only get the flop statistics of a single node, # it also build a full autograd graph for this node. # This makes sure we can analyze the dependencies of memory, and # decide which forward intermediate results should be kept until # backward is executed. # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): _node: Node = None def __repr__(self): if self.grad_fn: return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})" return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) node = subgraph.create_node('call_function', func, args_node, kwargs_node) out = super().__torch_dispatch__(func, types, args, kwargs) flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) node.meta['phase'] = phase # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during # `Phase.FORWARD` if phase == Phase.FORWARD: if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: node.meta['phase'] = Phase.PLACEHOLDER # TODO(yby): specify `saved_tensors` for backward memory estimation node.meta['saved_tensor'] = [] if phase == Phase.BACKWARD: node.meta['saved_tensor'] = normalize_tuple(out) def wrap(x): if isinstance(x, MetaTensor): x = FlopTensor(x) x._node = node return x out = tree_map(wrap, out) return out def wrap(x): if isinstance(x, torch.Tensor): x = FlopTensor(x) if is_autogradable(x): x.requires_grad_(True) x._node = subgraph.create_node('placeholder', 'placeholder', (subgraph._root,), name=subgraph._graph_namespace.create_name('input', x._tensor)) x._node.meta['phase'] = Phase.PLACEHOLDER x._node.meta['saved_tensor'] = [] return x # Basically, we need to detach the args and kwargs from the outer graph. args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) def pack(x): global cache, do_not_cache if isinstance(x, FlopTensor) and not x._tensor.uuid in cache: tensor = x._tensor.detach() tensor.uuid = x._tensor.uuid x._node.meta['saved_tensor'] += [tensor] if not do_not_cache: cache.add(x._tensor.uuid) return x def unpack(x): return x # `phase` will mark the phase of autograd from outside scope. phase = Phase.FORWARD # mark saved tensors with saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): if isinstance(target, str): # args[0] is the `self` object for this method call self_obj, *args_tail = args out = getattr(self_obj, target)(*args_tail, **kwargs) else: out = target(*args, **kwargs) # If the output is not a floating point `torch.Tensor` or it does not # requires grad, then we should not run backward for this node. if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))): grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)] phase = Phase.BACKWARD torch.autograd.backward( out, grad_out, ) graph_info = autograd_graph_analysis(subgraph) graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD] def extract_tensor(x: Any): if isinstance(x, MetaTensor): tensor = x._tensor.detach() tensor.uuid = x._tensor.uuid return tensor return x graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out))) def unwrap(x): return MetaTensor(x) if isinstance(x, torch.Tensor) else x return tree_map(unwrap, out), graph_info def profile_function(target: 'Target', device: str = 'meta') -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. Warnings: You may only use tensors with `device=meta` for this wrapped function. Only original `torch.nn.functional` are available. Examples: >>> input = torch.rand(100, 100, 100, 100, device='meta') >>> func = torch.nn.functional.relu >>> output, meta_info = profile_function(func)(input) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # find the grad for parameter in args and kwargs param_size = 0 def get_param_size(x): nonlocal param_size if isinstance(x, Parameter): param_size += activation_size(x) tree_map(get_param_size, args) tree_map(get_param_size, kwargs) # If there is an argument that this `call_function` is inplace, we should # still run the profiling but discard some results regarding `target` global do_not_cache inplace = kwargs.get('inplace', False) if inplace or target in [torch.nn.functional.relu]: do_not_cache = True kwargs['inplace'] = False if device == 'meta': out, meta = _profile_meta(func, *args, **kwargs) # currently we set the fwd_mem_tmp of ReLU to zero if target in [torch.nn.functional.relu]: meta.fwd_in = [] meta.fwd_tmp = [] meta.bwd_mem_out = 0 meta.fwd_mem_tmp = 0 else: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: kwargs['inplace'] = True do_not_cache = False meta.bwd_mem_out -= param_size return out, meta f.__name__ = target.__name__ func = target return f def profile_method(target: 'Target', device: str = 'meta') -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result assert isinstance(target, str), f'{target} instance is not str.' if device == 'meta': out, meta = _profile_meta(target, *args, **kwargs) else: out, meta = _profile_concrete(target, *args, **kwargs) return out, meta return f def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: """ Wrap a `call_module` node or `torch.nn` in order to record the memory cost and FLOPs of the execution. Warnings: You may only use tensors with `device=meta` for this wrapped function. Only original `torch.nn` are available. Example: >>> input = torch.rand(4, 3, 224, 224, device='meta') >>> mod = torch.nn.Conv2d(3, 128, 3) >>> output, meta_info = profile_module(mod)(input) """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # calculate parameter size param_size = parameter_size(module) # If there is an argument that this `call_module` is inplace, we should # still run the profiling but discard some results regarding `module`. global do_not_cache inplace = getattr(module, 'inplace', False) if inplace or type(module) in [torch.nn.ReLU]: do_not_cache = True module.inplace = False if device == 'meta': out, meta = _profile_meta(func, *args, **kwargs) # currently we set the fwd_tmp of ReLU to [] if type(module) in [torch.nn.ReLU]: meta.fwd_in = [] meta.fwd_tmp = [] meta.bwd_mem_out = 0 else: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: module.inplace = True do_not_cache = False # grad for param will not be counted meta.bwd_mem_out -= param_size return out, meta f.__name__ = module.__class__.__name__ func = module.forward return f