ColossalAI/colossalai/fx/profiler/profiler.py


from typing import Any, Callable, Dict, Tuple

import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map

from .dataflow import autograd_graph_analysis, Stage
from .memory import WEIRD_OPS
from .opcount import flop_mapping
from .tensor import MetaTensor

__all__ = ['profile_function', 'profile_module', 'profile_method']


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()
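
# Illustration (not from the original source): `normalize_tuple` lets the
# profiler treat single-output and multi-output aten ops uniformly, e.g.
#
#     normalize_tuple(torch.empty(2, device='meta'))    # -> (tensor,)
#     normalize_tuple((values, indices))                # -> unchanged
#
# while `is_autogradable` filters out integer and bool tensors, which cannot
# participate in backward.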


def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
    """
    Profile a Callable function with args and kwargs.

    Args:
        target (Callable): A Callable function
        args (Any): Positional arguments passed to `target`
        inplace (bool): Whether the target op is in-place; if so, the autograd
            profiling is skipped
        kwargs (Any): Keyword arguments passed to `target`

    Returns:
        out (Tuple[Any, ...]): The output of the profiled `target`.
        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
    """
    # This subgraph traces aten-level ops inside one node.
    subgraph = Graph()

    # `flop_count` serves as a global dictionary to store results.
    flop_count = {
        Stage.FORWARD: 0,
        Stage.LOSS: 0,
        Stage.BACKWARD: 0,
    }

    # `stage` will mark the stage of autograd from the outer scope.
    stage = Stage.FORWARD

    # FlopTensor not only gets the flop statistics of a single node,
    # it also builds a full autograd graph for this node.
    # This makes sure we can analyze the dependencies of memory, and
    # decide which forward intermediate results should be kept until
    # backward is executed.
    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):

        _node: Node

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

            def get_node(x):
                return None if not hasattr(x, '_node') else x._node

            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = subgraph.create_node('call_function', func, args_node, kwargs_node)

            def unwrap(x):
                # if x is an `nn.Parameter`, we can first wrap it with `FlopTensor`
                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                    x = FlopTensor(x.to('meta'))
                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)
            flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
            node.meta['out'] = normalize_tuple(out)
            node.meta['stage'] = stage

            def wrap(x):
                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x

            def set_node(x):
                # guard against non-tensor leaves (e.g. ints in multi-value outputs)
                if isinstance(x, FlopTensor):
                    x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)
            return out

    # `WEIRD_OPS` are tough to handle because they don't accept autograd
    # on meta tensors.
    if target not in WEIRD_OPS:

        def wrap(x):
            return FlopTensor(x.detach().requires_grad_(
                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
    else:

        def wrap(x):
            return FlopTensor(x.detach().requires_grad_(
                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    def set_placeholder(x):
        if isinstance(x, FlopTensor):
            x._node = subgraph.create_node('placeholder',
                                           'placeholder', (subgraph._root,),
                                           name=subgraph._graph_namespace.create_name('input', x._tensor))
            x._node.meta['stage'] = Stage.PLACEHOLDER
            x._node.meta['out'] = (x._tensor,)

    tree_map(set_placeholder, args)
    tree_map(set_placeholder, kwargs)

    def pack(x):
        if isinstance(x, FlopTensor):
            x._node.meta['saved'] = True
        return x

    def unpack(x):
        return x

    # mark saved tensors with `saved_tensors_hooks`
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        if isinstance(target, str):
            # args[0] is the `self` object for this method call
            self_obj, *args_tail = args
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        else:
            out = target(*args, **kwargs)

    # If the output is not a floating-point `torch.Tensor`, or it does not
    # require grad, then we should not run backward for this node.
    if is_autogradable(out) and out.requires_grad:
        stage = Stage.LOSS
        loss = out.sum()
        stage = Stage.BACKWARD
        loss.backward()

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]

    def unwrap(x):
        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

    return tree_map(unwrap, out), graph_info
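
# A minimal usage sketch for `_profile` (illustration only; external callers
# should go through the `profile_*` wrappers below):
#
#     x = torch.rand(100, 100, device='meta')
#     out, info = _profile(torch.nn.functional.relu, x)
#     print(info.fwd_flop, info.bwd_flop)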


def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or a `torch.nn.functional` call in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` functions are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, meta_info = profile_function(func)(input, inplace=False)
    """
    func = target

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # If an `inplace` keyword is passed, it is absorbed by the `inplace`
        # parameter of `_profile`, which then skips the autograd profiling.
        out, meta = _profile(func, *args, **kwargs)
        return out, meta

    f.__name__ = target.__name__
    return f


def profile_method(target: 'Target') -> Callable:
    """
    Wrap a `call_method` node in order to
    record the memory cost and FLOPs of the execution.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f'The target of a `call_method` node must be a str, but got {target}.'
        out, meta = _profile(target, *args, inplace=False, **kwargs)
        return out, meta

    return f
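
# A hedged example of wrapping a method call; `matmul` stands in for any
# `torch.Tensor` method name recorded by a `call_method` node:
#
#     x = torch.rand(4, 16, device='meta')
#     y = torch.rand(16, 8, device='meta')
#     out, meta_info = profile_method('matmul')(x, y)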


def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or a `torch.nn` module in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` modules are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, meta_info = profile_module(mod)(input)
    """
    func = module.forward

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # If the module was constructed as inplace (e.g. `nn.ReLU(inplace=True)`),
        # we should skip the autograd profiling.
        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
        return out, meta

    f.__name__ = module.__class__.__name__
    return f
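

# A minimal self-test sketch, mirroring the docstring examples above; it is
# kept under a main guard so importing this module stays side-effect free,
# and it assumes `flop_mapping` covers the aten ops `Conv2d` dispatches to.
if __name__ == '__main__':
    _input = torch.rand(4, 3, 224, 224, device='meta')
    _mod = torch.nn.Conv2d(3, 128, 3)
    _out, _meta = profile_module(_mod)(_input)
    print(f'fwd_flop={_meta.fwd_flop}, bwd_flop={_meta.bwd_flop}')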