[fx] refactor code for profiler / enable fake tensor movement. (#1646)

* [fx/profiling] provide a summary for MetaInfoProp.

* [fx/profiler] provide a summary table.

* [fx] optimize table repr.

* [fx] refactor code for profiler.

* [fx] add docstring.

* [fx] skip test.

* [fx] redo.

* [fx] fix import error for torch11.

Super Daniel 2022-09-27 10:26:52 +08:00 committed by GitHub
parent 5d0fdb9cb4
commit 6135e178b3
5 changed files with 103 additions and 67 deletions

@@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
from colossalai import META_COMPATIBILITY
INF = float("inf")
@@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
mem_limit -= parameter_size(gm)
# prepare data
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data)
chain = _normalize_flops(chain, flops)
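
Both the pofo and rotor solvers now guard the MetaTensor wrapping behind `META_COMPATIBILITY`. A minimal sketch of what this branch does, using the imports shown in the diff (the batch shape and device here are hypothetical, for illustration only):

```python
import torch
from colossalai import META_COMPATIBILITY

# Hypothetical sample batch; the shape is illustrative.
data = torch.rand(8, 3, 224, 224)
if META_COMPATIBILITY:
    from colossalai.fx.profiler import MetaTensor
    # Storage moves to the `meta` device, while `fake_device` records where
    # the tensor *would* live, so MetaInfoProp can trace without allocating.
    data = MetaTensor(data, fake_device=torch.device('cuda:0'))
    print(data.device)  # cuda:0 (advertised); no CUDA memory is actually used
```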

@@ -2,12 +2,12 @@ from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler.tensor import MetaTensor
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai import META_COMPATIBILITY
# this is the python compute table code from rotor
@@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,
node_list = linearize(gm, cnode)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data)

@@ -14,6 +14,7 @@ if META_COMPATIBILITY:
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
INPLACE_NEW = [

@@ -37,9 +37,28 @@ def detach(x):
x.requires_grad_(requires_grad)
def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""
Profile a Callable function with args and kwargs.
Profile a Callable function with args and kwargs on concrete devices.
Args:
target (Callable): A Callable function
args (Any): Argument
kwargs (Any): Argument
Raises:
NotImplementedError: TODO(yby)
Returns:
out (Tuple[Any, ...]): The argument value that was retrieved.
meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
raise NotImplementedError
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""
Profile a Callable function with args and kwargs on meta devices.
Args:
target (Callable): A Callable function
@@ -67,7 +86,7 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
# Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor):
_node: Node
_node: Node = None
def __repr__(self):
if self.grad_fn:
@@ -76,34 +95,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def get_node(x):
return None if not hasattr(x, '_node') else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
# do not allocate on physical devices
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
out = super().__torch_dispatch__(func, types, args, kwargs)
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaTensor):
fake_device = x.device
x = x._tensor
elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=WHATEVER but actually on backend=Meta
out = func(*args, **kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['phase'] = phase
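
The rewritten `__torch_dispatch__` above delegates the device bookkeeping to `MetaTensor.__torch_dispatch__` and keeps only node creation and FLOP counting. A stripped-down sketch of the wrapper-subclass mechanism it builds on (no graph nodes, no `flop_mapping`; `CountingTensor` is illustrative, not patch code):

```python
import torch
from torch.utils._pytree import tree_map

class CountingTensor(torch.Tensor):
    ops = []

    @staticmethod
    def __new__(cls, elem):
        # The wrapper holds no storage of its own; it mirrors elem's metadata.
        r = torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype,
                                                device=elem.device,
                                                requires_grad=elem.requires_grad)
        r._tensor = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every decomposed aten call lands here; a profiler would look
        # `func` up in flop_mapping at this point.
        cls.ops.append(func)
        unwrap = lambda x: x._tensor if isinstance(x, CountingTensor) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(lambda x: CountingTensor(x) if isinstance(x, torch.Tensor) else x, out)

y = CountingTensor(torch.randn(4)) + 1
print(CountingTensor.ops)  # e.g. [aten.add.Tensor]
```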
@@ -114,52 +111,41 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
node.meta['phase'] = Phase.PLACEHOLDER
# TODO: specify `saved_tensors` for backward memory estimation
# TODO(yby): specify `saved_tensors` for backward memory estimation
node.meta['saved_tensor'] = []
if phase == Phase.BACKWARD:
node.meta['saved_tensor'] = normalize_tuple(out)
def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
def set_node(x):
x._node = node
if isinstance(x, MetaTensor):
x = FlopTensor(x)
x._node = node
return x
out = tree_map(wrap, out)
tree_map(set_node, out)
return out
def wrap(x):
fake_device = None
if isinstance(x, MetaTensor):
fake_device = x.device
x = x._tensor
detach(x)
return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
# Basically, we need to detach the args and kwargs from the outer graph.
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
def set_placeholder(x):
if isinstance(x, FlopTensor):
if isinstance(x, torch.Tensor):
x = FlopTensor(x)
if is_autogradable(x):
x.requires_grad_(True)
x._node = subgraph.create_node('placeholder',
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['phase'] = Phase.PLACEHOLDER
x._node.meta['saved_tensor'] = []
detach(x)
return x
tree_map(set_placeholder, args)
tree_map(set_placeholder, kwargs)
# Basically, we need to detach the args and kwargs from the outer graph.
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
def pack(x):
global cache
if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
x._node.meta['saved_tensor'] += [x._tensor]
x._node.meta['saved_tensor'] += [x]
cache.add(x._tensor.data_ptr)
return x
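
`pack` now stores the wrapper `x` itself rather than the raw `x._tensor`. For context, a minimal sketch of the standard PyTorch hook the profiler plugs into here (stock `torch.autograd.graph` API, not code from this patch):

```python
import torch

saved = []

def pack(x):
    saved.append(x)   # the profiler additionally dedupes by data_ptr (see `cache`)
    return x

def unpack(x):
    return x

a = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = (a * a).sum()
loss.backward()
print(len(saved))  # every tensor autograd stashed for backward passed through pack
```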
@@ -191,16 +177,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
graph_info.fwd_mem_out = activation_size(out)
def unwrap(x):
if isinstance(x, FlopTensor):
fake_device = x.device
x = x._tensor
detach(x)
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
return tree_map(unwrap, out), graph_info
def profile_function(target: 'Target') -> Callable:
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -222,7 +204,10 @@ def profile_function(target: 'Target') -> Callable:
inplace = kwargs.get('inplace', False)
if inplace:
kwargs['inplace'] = False
out, meta = _profile(func, *args, **kwargs)
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
if target in [torch.nn.functional.relu]:
meta.save_fwd_in = False
@@ -234,7 +219,7 @@ def profile_function(target: 'Target') -> Callable:
return f
def profile_method(target: 'Target') -> Callable:
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
@@ -243,13 +228,16 @@ def profile_method(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
out, meta = _profile(target, *args, **kwargs)
if device == 'meta':
out, meta = _profile_meta(target, *args, **kwargs)
else:
out, meta = _profile_concrete(target, *args, **kwargs)
return out, meta
return f
def profile_module(module: torch.nn.Module) -> Callable:
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
@@ -271,7 +259,10 @@ def profile_module(module: torch.nn.Module) -> Callable:
inplace = getattr(module, 'inplace', False)
if inplace:
module.inplace = False
out, meta = _profile(func, *args, **kwargs)
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
# super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
# is the only inplace activation function that discard its input.
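
An illustrative check (not patch code) of why inplace activations need the special-casing above: an in-place ReLU overwrites its input buffer, so the input activation cannot be counted as saved for backward.

```python
import torch

x = torch.randn(4)
y = torch.nn.functional.relu(x, inplace=True)
print(x.data_ptr() == y.data_ptr())  # True: the input buffer is reused
```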

@@ -1,5 +1,9 @@
from copy import deepcopy
from typing import Optional, Union, overload
import torch
from torch.utils._pytree import tree_map, tree_flatten
from torch.types import _bool, _dtype, _device
from functools import singledispatchmethod
__all__ = ['MetaTensor']
@@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem, fake_device=None):
# Avoid multiple wrapping
if isinstance(elem, MetaTensor):
fake_device = elem.device if fake_device is None else fake_device
elem = elem._tensor
# The wrapping tensor (MetaTensor) shouldn't hold any
# memory for the class in question, but it should still
# advertise the same device as before
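
A quick sketch of what the new branch buys (illustrative; the behavior follows directly from the code above): construction becomes idempotent, so re-wrapping a MetaTensor unwraps the inner `_tensor` and inherits its fake device unless a new one is passed.

```python
import torch
from colossalai.fx.profiler import MetaTensor

x = MetaTensor(torch.rand(4), fake_device=torch.device('cuda:0'))
y = MetaTensor(x)                                   # no nested wrapper
z = MetaTensor(x, fake_device=torch.device('cpu'))  # explicit override wins
print(y.device, z.device)                           # cuda:0 cpu
```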
@@ -74,3 +83,32 @@ class MetaTensor(torch.Tensor):
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
@singledispatchmethod
def to(self, *args, **kwargs) -> torch.Tensor:
"""An extension of `torch.Tensor.to()` to MetaTensor
Returns:
result (MetaTensor): MetaTensor
Usage:
>>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
>>> tensor.to(torch.uint8)
MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
>>> tensor.to(torch.device('cuda:42'))
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
>>> tensor.to('vulkan')
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
"""
# this imitates c++ function in the way of @overload
return super().to(*args, **kwargs)
@to.register
def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
result = super().to(dtype, non_blocking, copy) if dtype is not None else self
return MetaTensor(deepcopy(result), fake_device=device)
@to.register
def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
result = super().to(dtype, non_blocking, copy) if dtype is not None else self
return MetaTensor(deepcopy(result), fake_device=device)
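
A self-contained sketch of the dispatch rule these overloads rely on (`Dispatcher` is illustrative, not the MetaTensor class): `singledispatchmethod` picks an overload by the type of the first positional argument, so `.to('cuda:0')` hits the `str` overload, `.to(torch.device(...))` hits the `torch.device` overload, and `.to(torch.uint8)` falls through to the generic implementation, which here stands in for `torch.Tensor.to()`.

```python
import torch
from functools import singledispatchmethod

class Dispatcher:
    @singledispatchmethod
    def to(self, *args, **kwargs):
        return 'generic'            # stand-in for super().to(*args, **kwargs)

    @to.register
    def _(self, device: str, dtype=None):
        return f'str overload: {device}'

    @to.register
    def _(self, device: torch.device, dtype=None):
        return f'device overload: {device}'

d = Dispatcher()
print(d.to('cuda:0'))               # str overload: cuda:0
print(d.to(torch.device('cpu')))    # device overload: cpu
print(d.to(torch.uint8))            # generic (dtype is not a registered type)
```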