mirror of https://github.com/hpcaitech/ColossalAI
[fx] refactor code for profiler / enable fake tensor movement. (#1646)
* [fx/profiling] provide summary for MetaInfoProp. * [fx/profiler] provide a table of summary. * [fx/profiler] provide a table of summary. * [fx/profiler] provide a table of summary. * [fx/profiler] provide a table of summary. * [fx] optimize table repr. * [fx] optimize table repr. * [fx] refactor code for profiler. * [fx] add docstring. * [fx] add docstring. * [fx] skip test. * [fx] redo. * [fx] redo. * [fx] fix import error for torch11. * [fx] fix import error for torch11.pull/1654/head
parent
5d0fdb9cb4
commit
6135e178b3
|
@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
|
|||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
|
||||
from colossalai import META_COMPATIBILITY
|
||||
|
||||
INF = float("inf")
|
||||
|
||||
|
@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
|
|||
mem_limit -= parameter_size(gm)
|
||||
|
||||
# prepare data
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
chain = _normalize_flops(chain, flops)
|
||||
|
|
|
@ -2,12 +2,12 @@ from typing import List, Tuple
|
|||
from torch.fx import Node
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import activation_size, parameter_size
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai import META_COMPATIBILITY
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
|
@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,
|
|||
|
||||
node_list = linearize(gm, cnode)
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
|
|
|
@ -14,6 +14,7 @@ if META_COMPATIBILITY:
|
|||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
INPLACE_NEW = [
|
||||
|
|
|
@ -37,9 +37,28 @@ def detach(x):
|
|||
x.requires_grad_(requires_grad)
|
||||
|
||||
|
||||
def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
|
||||
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
|
||||
"""
|
||||
Profile a Callable function with args and kwargs.
|
||||
Profile a Callable function with args and kwargs on concrete devices.
|
||||
|
||||
Args:
|
||||
target (Callable): A Callable function
|
||||
args (Any): Argument
|
||||
kwargs (Any): Argument
|
||||
|
||||
Raises:
|
||||
NotImplementedError: TODO(yby)
|
||||
|
||||
Returns:
|
||||
out (Tuple[Any, ...]): The argument value that was retrieved.
|
||||
meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
|
||||
"""
|
||||
Profile a Callable function with args and kwargs on meta devices.
|
||||
|
||||
Args:
|
||||
target (Callable): A Callable function
|
||||
|
@ -67,7 +86,7 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
|
|||
# Hopefully, this attempt will provide a better estimation of memory.
|
||||
class FlopTensor(MetaTensor):
|
||||
|
||||
_node: Node
|
||||
_node: Node = None
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
|
@ -76,34 +95,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
|
|||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
def get_node(x):
|
||||
return None if not hasattr(x, '_node') else x._node
|
||||
|
||||
args_node = tree_map(get_node, args)
|
||||
kwargs_node = tree_map(get_node, kwargs)
|
||||
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
|
||||
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
|
||||
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
|
||||
|
||||
# do not allocate on physical devices
|
||||
if 'device' in kwargs:
|
||||
fake_device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
out = super().__torch_dispatch__(func, types, args, kwargs)
|
||||
|
||||
def unwrap(x):
|
||||
nonlocal fake_device
|
||||
if isinstance(x, MetaTensor):
|
||||
fake_device = x.device
|
||||
x = x._tensor
|
||||
elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
||||
fake_device = x.device
|
||||
x = x.to(torch.device('meta'))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
# run aten for backend=WHATEVER but actually on backend=Meta
|
||||
out = func(*args, **kwargs)
|
||||
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
|
||||
node.meta['phase'] = phase
|
||||
|
||||
|
@ -114,52 +111,41 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
|
|||
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
|
||||
node.meta['phase'] = Phase.PLACEHOLDER
|
||||
|
||||
# TODO: specify `saved_tensors` for backward memory estimation
|
||||
# TODO(yby): specify `saved_tensors` for backward memory estimation
|
||||
node.meta['saved_tensor'] = []
|
||||
if phase == Phase.BACKWARD:
|
||||
node.meta['saved_tensor'] = normalize_tuple(out)
|
||||
|
||||
def wrap(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
nonlocal fake_device
|
||||
if not x.is_meta:
|
||||
x = x.to(torch.device('meta'))
|
||||
return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
def set_node(x):
|
||||
x._node = node
|
||||
if isinstance(x, MetaTensor):
|
||||
x = FlopTensor(x)
|
||||
x._node = node
|
||||
return x
|
||||
|
||||
out = tree_map(wrap, out)
|
||||
tree_map(set_node, out)
|
||||
return out
|
||||
|
||||
def wrap(x):
|
||||
fake_device = None
|
||||
if isinstance(x, MetaTensor):
|
||||
fake_device = x.device
|
||||
x = x._tensor
|
||||
detach(x)
|
||||
return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
|
||||
|
||||
# Basically, we need to detach the args and kwargs from the outer graph.
|
||||
args = tree_map(wrap, args)
|
||||
kwargs = tree_map(wrap, kwargs)
|
||||
|
||||
def set_placeholder(x):
|
||||
if isinstance(x, FlopTensor):
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = FlopTensor(x)
|
||||
if is_autogradable(x):
|
||||
x.requires_grad_(True)
|
||||
x._node = subgraph.create_node('placeholder',
|
||||
'placeholder', (subgraph._root,),
|
||||
name=subgraph._graph_namespace.create_name('input', x._tensor))
|
||||
x._node.meta['phase'] = Phase.PLACEHOLDER
|
||||
x._node.meta['saved_tensor'] = []
|
||||
detach(x)
|
||||
return x
|
||||
|
||||
tree_map(set_placeholder, args)
|
||||
tree_map(set_placeholder, kwargs)
|
||||
# Basically, we need to detach the args and kwargs from the outer graph.
|
||||
args = tree_map(wrap, args)
|
||||
kwargs = tree_map(wrap, kwargs)
|
||||
|
||||
def pack(x):
|
||||
global cache
|
||||
if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
|
||||
x._node.meta['saved_tensor'] += [x._tensor]
|
||||
x._node.meta['saved_tensor'] += [x]
|
||||
cache.add(x._tensor.data_ptr)
|
||||
return x
|
||||
|
||||
|
@ -191,16 +177,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
|
|||
graph_info.fwd_mem_out = activation_size(out)
|
||||
|
||||
def unwrap(x):
|
||||
if isinstance(x, FlopTensor):
|
||||
fake_device = x.device
|
||||
x = x._tensor
|
||||
detach(x)
|
||||
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
|
||||
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(unwrap, out), graph_info
|
||||
|
||||
|
||||
def profile_function(target: 'Target') -> Callable:
|
||||
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
||||
"""
|
||||
Wrap a `call_function` node or `torch.nn.functional` in order to
|
||||
record the memory cost and FLOPs of the execution.
|
||||
|
@ -222,7 +204,10 @@ def profile_function(target: 'Target') -> Callable:
|
|||
inplace = kwargs.get('inplace', False)
|
||||
if inplace:
|
||||
kwargs['inplace'] = False
|
||||
out, meta = _profile(func, *args, **kwargs)
|
||||
if device == 'meta':
|
||||
out, meta = _profile_meta(func, *args, **kwargs)
|
||||
else:
|
||||
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||
if inplace:
|
||||
if target in [torch.nn.functional.relu]:
|
||||
meta.save_fwd_in = False
|
||||
|
@ -234,7 +219,7 @@ def profile_function(target: 'Target') -> Callable:
|
|||
return f
|
||||
|
||||
|
||||
def profile_method(target: 'Target') -> Callable:
|
||||
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
|
||||
"""
|
||||
Wrap a `call_method` node
|
||||
record the memory cost and FLOPs of the execution.
|
||||
|
@ -243,13 +228,16 @@ def profile_method(target: 'Target') -> Callable:
|
|||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||
# execute the method and return the result
|
||||
assert isinstance(target, str), f'{target} instance is not str.'
|
||||
out, meta = _profile(target, *args, **kwargs)
|
||||
if device == 'meta':
|
||||
out, meta = _profile_meta(target, *args, **kwargs)
|
||||
else:
|
||||
out, meta = _profile_concrete(target, *args, **kwargs)
|
||||
return out, meta
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def profile_module(module: torch.nn.Module) -> Callable:
|
||||
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
|
||||
"""
|
||||
Wrap a `call_module` node or `torch.nn` in order to
|
||||
record the memory cost and FLOPs of the execution.
|
||||
|
@ -271,7 +259,10 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
|||
inplace = getattr(module, 'inplace', False)
|
||||
if inplace:
|
||||
module.inplace = False
|
||||
out, meta = _profile(func, *args, **kwargs)
|
||||
if device == 'meta':
|
||||
out, meta = _profile_meta(func, *args, **kwargs)
|
||||
else:
|
||||
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||
if inplace:
|
||||
# super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
|
||||
# is the only inplace activation function that discard its input.
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
from copy import deepcopy
|
||||
from typing import Optional, Union, overload
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
from torch.types import _bool, _dtype, _device
|
||||
from functools import singledispatchmethod
|
||||
|
||||
__all__ = ['MetaTensor']
|
||||
|
||||
|
@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):
|
|||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, fake_device=None):
|
||||
# Avoid multiple wrapping
|
||||
if isinstance(elem, MetaTensor):
|
||||
fake_device = elem.device if fake_device is None else fake_device
|
||||
elem = elem._tensor
|
||||
|
||||
# The wrapping tensor (MetaTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
|
@ -74,3 +83,32 @@ class MetaTensor(torch.Tensor):
|
|||
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(wrap, out)
|
||||
|
||||
@singledispatchmethod
|
||||
def to(self, *args, **kwargs) -> torch.Tensor:
|
||||
"""An extension of `torch.Tensor.to()` to MetaTensor
|
||||
|
||||
Returns:
|
||||
result (MetaTensor): MetaTensor
|
||||
|
||||
Usage:
|
||||
>>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
|
||||
>>> tensor.to(torch.uint8)
|
||||
MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
|
||||
>>> tensor.to(torch.device('cuda:42'))
|
||||
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
|
||||
>>> tensor.to('vulkan')
|
||||
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
|
||||
"""
|
||||
# this imitates c++ function in the way of @overload
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
@to.register
|
||||
def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
|
||||
result = super().to(dtype, non_blocking, copy) if dtype is not None else self
|
||||
return MetaTensor(deepcopy(result), fake_device=device)
|
||||
|
||||
@to.register
|
||||
def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
|
||||
result = super().to(dtype, non_blocking, copy) if dtype is not None else self
|
||||
return MetaTensor(deepcopy(result), fake_device=device)
|
||||
|
|
Loading…
Reference in New Issue