[fx] refactor code for profiler / enable fake tensor movement. (#1646)

* [fx/profiling] provide summary for MetaInfoProp.

* [fx/profiler] provide a table of summary.

* [fx/profiler] provide a table of summary.

* [fx/profiler] provide a table of summary.

* [fx/profiler] provide a table of summary.

* [fx] optimize table repr.

* [fx] optimize table repr.

* [fx] refactor code for profiler.

* [fx] add docstring.

* [fx] add docstring.

* [fx] skip test.

* [fx] redo.

* [fx] redo.

* [fx] fix import error for torch11.

* [fx] fix import error for torch11.
Super Daniel, 2022-09-27 10:26:52 +08:00, committed by GitHub
parent 5d0fdb9cb4
commit 6135e178b3
5 changed files with 103 additions and 67 deletions

View File

@@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
+from colossalai import META_COMPATIBILITY

 INF = float("inf")

@@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
     mem_limit -= parameter_size(gm)

     # prepare data
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data)
     chain = _normalize_flops(chain, flops)
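
Note: the gate above is what keeps the solver importable on torch 1.11, where the meta-tensor profiler is unavailable (see the "fix import error for torch11" commits). A minimal end-to-end sketch of how the wrapped input feeds MetaInfoProp, assuming the usual ColoTracer/ColoGraphModule entry points; the model and shapes are illustrative:

import torch
import torchvision.models as tm
from colossalai import META_COMPATIBILITY
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = tm.resnet18()
gm = ColoGraphModule(model, ColoTracer().trace(model))
data = torch.rand(2, 3, 224, 224, device='meta')
if META_COMPATIBILITY:
    from colossalai.fx.profiler import MetaTensor
    # record the device the input would really live on, without allocating there
    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)    # annotates each fx node with memory/FLOP estimates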

View File

@@ -2,12 +2,12 @@ from typing import List, Tuple
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import activation_size, parameter_size
-from colossalai.fx.profiler.tensor import MetaTensor
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai import META_COMPATIBILITY

 # this is the python compute table code from rotor

@@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,
     node_list = linearize(gm, cnode)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
-    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data)
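
The same gate replaces the previously unconditional wrap, and the key fix is that the module-level `from colossalai.fx.profiler.tensor import MetaTensor` is gone, so merely importing the solver no longer fails on torch 1.11. For reference, a flag like META_COMPATIBILITY can be computed from the torch version; this is a hedged sketch, and the actual definition inside colossalai may differ:

from packaging import version
import torch

# Assumption: the meta-tensor aten coverage needed here arrived with torch 1.12.
META_COMPATIBILITY = version.parse(torch.__version__) >= version.parse('1.12.0')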

View File

@@ -14,6 +14,7 @@ if META_COMPATIBILITY:
         aten.transpose.int,
         aten.view.default,
         aten._unsafe_view.default,
+        aten._reshape_alias.default,
     ]

     INPLACE_NEW = [
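
ALIAS_ATEN lists aten ops whose outputs alias their inputs, which the profiler marks as placeholders rather than fresh allocations (see the ALIAS_ATEN check in the profiler hunk below). `aten._reshape_alias.default` belongs here because reshaping a contiguous tensor shares storage with its source, as this quick check shows:

import torch

x = torch.rand(4, 4)
y = x.reshape(16)    # on a contiguous tensor this lowers to aten._reshape_alias
assert y.data_ptr() == x.data_ptr()    # same storage: an alias, not a copy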

View File

@@ -37,9 +37,28 @@ def detach(x):
     x.requires_grad_(requires_grad)


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
+def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
-    Profile a Callable function with args and kwargs.
+    Profile a Callable function with args and kwargs on concrete devices.
+
+    Args:
+        target (Callable): A Callable function
+        args (Any): Argument
+        kwargs (Any): Argument
+
+    Raises:
+        NotImplementedError: TODO(yby)
+
+    Returns:
+        out (Tuple[Any, ...]): The argument value that was retrieved.
+        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
+    """
+    raise NotImplementedError
+
+
+def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
+    """
+    Profile a Callable function with args and kwargs on meta devices.

     Args:
         target (Callable): A Callable function

@@ -67,7 +86,7 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
     # Hopefully, this attempt will provide a better estimation of memory.
     class FlopTensor(MetaTensor):

-        _node: Node
+        _node: Node = None

         def __repr__(self):
             if self.grad_fn:

@@ -76,34 +95,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
-            def get_node(x):
-                return None if not hasattr(x, '_node') else x._node
-
-            args_node = tree_map(get_node, args)
-            kwargs_node = tree_map(get_node, kwargs)
+            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
+            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)

-            # do not allocate on physical devices
-            if 'device' in kwargs:
-                fake_device = kwargs['device']
-                kwargs['device'] = torch.device('meta')
-
-            def unwrap(x):
-                nonlocal fake_device
-                if isinstance(x, MetaTensor):
-                    fake_device = x.device
-                    x = x._tensor
-                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    fake_device = x.device
-                    x = x.to(torch.device('meta'))
-                return x
-
-            args = tree_map(unwrap, args)
-            kwargs = tree_map(unwrap, kwargs)
-
-            # run aten for backend=WHATEVER but actually on backend=Meta
-            out = func(*args, **kwargs)
+            out = super().__torch_dispatch__(func, types, args, kwargs)

             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
             node.meta['phase'] = phase
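
This hunk is the core of the refactor: all device spoofing now lives in `MetaTensor.__torch_dispatch__`, and `FlopTensor` shrinks to graph bookkeeping plus FLOP accounting before deferring to `super()`. A self-contained sketch of that layering, with simplified stand-in names (FakeDeviceTensor/CountingTensor are illustrative, not the real classes, which additionally track `fake_device` and fx nodes):

import torch
from torch.utils._pytree import tree_map

class FakeDeviceTensor(torch.Tensor):
    # Base wrapper: holds a meta tensor and runs every aten op on the meta backend.

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype,
                                                device='meta', requires_grad=elem.requires_grad)
        r.elem = elem.to('meta')
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda x: x.elem if isinstance(x, FakeDeviceTensor) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # `cls` is the dynamic class here, so subclass outputs stay wrapped as the subclass
        wrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
        return tree_map(wrap, out)

class CountingTensor(FakeDeviceTensor):
    ops = []    # crude op log, standing in for the subgraph/FLOP bookkeeping

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        cls.ops.append(func)    # record first, like FlopTensor does
        return super().__torch_dispatch__(func, types, args, kwargs)

x = CountingTensor(torch.rand(4, 4))
y = x @ x    # executes on the meta backend, no real allocation
print(CountingTensor.ops)    # e.g. [aten.mm.default]; exact entries vary by torch version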
@@ -114,52 +111,41 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
             if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                 node.meta['phase'] = Phase.PLACEHOLDER

-            # TODO: specify `saved_tensors` for backward memory estimation
+            # TODO(yby): specify `saved_tensors` for backward memory estimation
             node.meta['saved_tensor'] = []
             if phase == Phase.BACKWARD:
                 node.meta['saved_tensor'] = normalize_tuple(out)

             def wrap(x):
-                if isinstance(x, torch.Tensor):
-                    nonlocal fake_device
-                    if not x.is_meta:
-                        x = x.to(torch.device('meta'))
-                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
-
-            def set_node(x):
-                x._node = node
+                if isinstance(x, MetaTensor):
+                    x = FlopTensor(x)
+                    x._node = node
+                return x

             out = tree_map(wrap, out)
-            tree_map(set_node, out)
             return out

     def wrap(x):
-        fake_device = None
-        if isinstance(x, MetaTensor):
-            fake_device = x.device
-            x = x._tensor
-            detach(x)
-        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
-
-    # Basically, we need to detach the args and kwargs from the outer graph.
-    args = tree_map(wrap, args)
-    kwargs = tree_map(wrap, kwargs)
-
-    def set_placeholder(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, torch.Tensor):
+            x = FlopTensor(x)
+            if is_autogradable(x):
+                x.requires_grad_(True)
             x._node = subgraph.create_node('placeholder',
                                            'placeholder', (subgraph._root,),
                                            name=subgraph._graph_namespace.create_name('input', x._tensor))
             x._node.meta['phase'] = Phase.PLACEHOLDER
             x._node.meta['saved_tensor'] = []
+        detach(x)
+        return x

-    tree_map(set_placeholder, args)
-    tree_map(set_placeholder, kwargs)
+    # Basically, we need to detach the args and kwargs from the outer graph.
+    args = tree_map(wrap, args)
+    kwargs = tree_map(wrap, kwargs)

     def pack(x):
         global cache
         if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x._tensor]
+            x._node.meta['saved_tensor'] += [x]
             cache.add(x._tensor.data_ptr)
         return x

@@ -191,16 +177,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
     graph_info.fwd_mem_out = activation_size(out)

     def unwrap(x):
-        if isinstance(x, FlopTensor):
-            fake_device = x.device
-            x = x._tensor
-            detach(x)
-        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
+        return MetaTensor(x) if isinstance(x, torch.Tensor) else x

     return tree_map(unwrap, out), graph_info


-def profile_function(target: 'Target') -> Callable:
+def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.

@@ -222,7 +204,10 @@ def profile_function(target: 'Target') -> Callable:
         inplace = kwargs.get('inplace', False)
         if inplace:
             kwargs['inplace'] = False
-        out, meta = _profile(func, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             if target in [torch.nn.functional.relu]:
                 meta.save_fwd_in = False

@@ -234,7 +219,7 @@ def profile_function(target: 'Target') -> Callable:
     return f


-def profile_method(target: 'Target') -> Callable:
+def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_method` node
     record the memory cost and FLOPs of the execution.

@@ -243,13 +228,16 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(target, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(target, *args, **kwargs)
         return out, meta

     return f


-def profile_module(module: torch.nn.Module) -> Callable:
+def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.

@@ -271,7 +259,10 @@ def profile_module(module: torch.nn.Module) -> Callable:
         inplace = getattr(module, 'inplace', False)
         if inplace:
             module.inplace = False
-        out, meta = _profile(func, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             # super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
             # is the only inplace activation function that discard its input.
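
With the `device` switch threaded through all three wrappers, callers can choose the meta path explicitly (the concrete path stays a stub until `_profile_concrete` is implemented). A hypothetical usage sketch: the wrapper is assumed to return `(out, GraphInfo)` the way the `profile_method` wrapper above does, and the import path is assumed to mirror this file's package:

import torch
from colossalai.fx.profiler import MetaTensor, profile_function    # import path assumed

prof_linear = profile_function(torch.nn.functional.linear, device='meta')
x = MetaTensor(torch.empty(32, 128, device='meta'), fake_device='cuda:0')
w = MetaTensor(torch.empty(256, 128, device='meta'), fake_device='cuda:0')
out, info = prof_linear(x, w)    # runs entirely on meta tensors
print(info.fwd_mem_out)          # estimated forward output memory, in bytes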

View File

@@ -1,5 +1,9 @@
+from copy import deepcopy
+from typing import Optional, Union, overload
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
+from torch.types import _bool, _dtype, _device
+from functools import singledispatchmethod

 __all__ = ['MetaTensor']

@@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, elem, fake_device=None):
+        # Avoid multiple wrapping
+        if isinstance(elem, MetaTensor):
+            fake_device = elem.device if fake_device is None else fake_device
+            elem = elem._tensor
+
         # The wrapping tensor (MetaTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before

@@ -74,3 +83,32 @@ class MetaTensor(torch.Tensor):
             return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

         return tree_map(wrap, out)
+
+    @singledispatchmethod
+    def to(self, *args, **kwargs) -> torch.Tensor:
+        """An extension of `torch.Tensor.to()` to MetaTensor
+
+        Returns:
+            result (MetaTensor): MetaTensor
+
+        Usage:
+            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
+            >>> tensor.to(torch.uint8)
+            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
+            >>> tensor.to(torch.device('cuda:42'))
+            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
+            >>> tensor.to('vulkan')
+            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
+        """
+        # this imitates c++ function in the way of @overload
+        return super().to(*args, **kwargs)
+
+    @to.register
+    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
+        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
+        return MetaTensor(deepcopy(result), fake_device=device)
+
+    @to.register
+    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
+        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
+        return MetaTensor(deepcopy(result), fake_device=device)
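
These overloads are the "fake tensor movement" from the commit title. `@singledispatchmethod` dispatches on the type of the first positional argument: a `str` or `torch.device` hits the registered overloads, which only relabel `fake_device` and never touch a physical backend, while any other first argument (e.g. a dtype) falls through to plain `torch.Tensor.to` semantics on the underlying meta tensor. A small usage sketch, consistent with the docstring above:

import torch
from colossalai.fx.profiler import MetaTensor

t = MetaTensor(torch.rand(10), fake_device='cpu')
t_cuda = t.to('cuda:0')                 # str overload: no CUDA context or allocation
t_cuda2 = t.to(torch.device('cuda:1'))  # torch.device overload: only fake_device changes
t_half = t.to(torch.float16)            # falls through to the base `to` on meta storage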