mirror of https://github.com/hpcaitech/ColossalAI
[fx] refactor code for profiler / enable fake tensor movement. (#1646)
* [fx/profiling] provide summary for MetaInfoProp.
* [fx/profiler] provide a table of summary.
* [fx/profiler] provide a table of summary.
* [fx/profiler] provide a table of summary.
* [fx/profiler] provide a table of summary.
* [fx] optimize table repr.
* [fx] optimize table repr.
* [fx] refactor code for profiler.
* [fx] add docstring.
* [fx] add docstring.
* [fx] skip test.
* [fx] redo.
* [fx] redo.
* [fx] fix import error for torch11.
* [fx] fix import error for torch11.

Branch: pull/1654/head
parent 5d0fdb9cb4
commit 6135e178b3
@@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
+from colossalai import META_COMPATIBILITY

 INF = float("inf")

@@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
     mem_limit -= parameter_size(gm)

     # prepare data
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data)
     chain = _normalize_flops(chain, flops)
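Both checkpoint solvers now follow the same guarded pattern: profiling only switches to meta tensors when META_COMPATIBILITY is set. A minimal sketch of the resulting call sequence; `gm` (a traced ColoGraphModule) and `data` (a sample input) are stand-ins, not names from the diff:

from colossalai import META_COMPATIBILITY
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

def run_meta_prop(gm, data):
    # gate on META_COMPATIBILITY, exactly as solver_pofo/solver_rotor do above
    if META_COMPATIBILITY:
        from colossalai.fx.profiler import MetaTensor
        # execute on the meta device while advertising the device that the
        # model's parameters actually live on
        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(data)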
@@ -2,12 +2,12 @@ from typing import List, Tuple
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import activation_size, parameter_size
-from colossalai.fx.profiler.tensor import MetaTensor
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai import META_COMPATIBILITY


 # this is the python compute table code from rotor
@@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,

     node_list = linearize(gm, cnode)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
-    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)

     chain: Chain = _construct_chain(node_list, data)
@@ -14,6 +14,7 @@ if META_COMPATIBILITY:
         aten.transpose.int,
         aten.view.default,
         aten._unsafe_view.default,
+        aten._reshape_alias.default,
     ]

     INPLACE_NEW = [
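aten._reshape_alias.default joins a list of view-like ops whose outputs alias their input's storage instead of allocating new memory, which is why the profiler treats them specially. A quick standalone check of that aliasing behaviour in plain PyTorch (illustration only, not from the commit):

import torch

x = torch.rand(4, 4)
y = x.reshape(16)    # on a contiguous tensor this lowers to a view/alias op
# same underlying storage, so the reshape costs no extra activation memory
assert y.data_ptr() == x.data_ptr()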
@@ -37,9 +37,28 @@ def detach(x):
     x.requires_grad_(requires_grad)


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
+def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
-    Profile a Callable function with args and kwargs.
+    Profile a Callable function with args and kwargs on concrete devices.

+    Args:
+        target (Callable): A Callable function
+        args (Any): Argument
+        kwargs (Any): Argument
+
+    Raises:
+        NotImplementedError: TODO(yby)
+
+    Returns:
+        out (Tuple[Any, ...]): The argument value that was retrieved.
+        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
+    """
+    raise NotImplementedError
+
+
+def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
+    """
+    Profile a Callable function with args and kwargs on meta devices.
+
     Args:
         target (Callable): A Callable function
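The old _profile entry point is split in two: _profile_meta keeps the meta-device estimation path, while _profile_concrete reserves a slot for profiling on real devices (for now it raises NotImplementedError). A hedged usage sketch of the meta path; the module path in the import and the choice of F.relu are assumptions for illustration:

import torch
import torch.nn.functional as F
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.profiler.profiler import _profile_meta   # assumed module path

x = MetaTensor(torch.rand(8, 16), fake_device='cpu')
out, info = _profile_meta(F.relu, x)    # -> (outputs, GraphInfo)
print(info)    # memory cost and FLOPs estimated with MetaTensor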
@@ -67,7 +86,7 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     # Hopefully, this attempt will provide a better estimation of memory.
     class FlopTensor(MetaTensor):

-        _node: Node
+        _node: Node = None

         def __repr__(self):
             if self.grad_fn:
@@ -76,34 +95,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-            def get_node(x):
-                return None if not hasattr(x, '_node') else x._node
-
-            args_node = tree_map(get_node, args)
-            kwargs_node = tree_map(get_node, kwargs)
+            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
+            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)

-            # do not allocate on physical devices
-            if 'device' in kwargs:
-                fake_device = kwargs['device']
-                kwargs['device'] = torch.device('meta')
-
-            def unwrap(x):
-                nonlocal fake_device
-                if isinstance(x, MetaTensor):
-                    fake_device = x.device
-                    x = x._tensor
-                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    fake_device = x.device
-                    x = x.to(torch.device('meta'))
-                return x
-
-            args = tree_map(unwrap, args)
-            kwargs = tree_map(unwrap, kwargs)
-
-            # run aten for backend=WHATEVER but actually on backend=Meta
-            out = func(*args, **kwargs)
+            out = super().__torch_dispatch__(func, types, args, kwargs)
             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
             node.meta['phase'] = phase

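The dispatch body shrinks because the unwrap/fake-device logic now lives in MetaTensor.__torch_dispatch__; FlopTensor only records a graph node and the FLOP count before deferring to super(). The same delegation pattern in miniature, with a hypothetical OpLogTensor that is not part of the commit:

import torch
from colossalai.fx.profiler import MetaTensor

class OpLogTensor(MetaTensor):
    """Log every aten op that runs; MetaTensor handles the meta execution."""
    _log = []

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        cls._log.append(func)    # bookkeeping, analogous to flop_count[phase]
        # the parent class unwraps args and runs the op on the meta device
        return super().__torch_dispatch__(func, types, args, kwargs)

As in the diff, outputs come back as plain MetaTensor; FlopTensor re-wraps them (see the inner wrap helper in the next hunk) so the subclass survives across ops.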
@@ -114,52 +111,41 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
                 if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                     node.meta['phase'] = Phase.PLACEHOLDER

-            # TODO: specify `saved_tensors` for backward memory estimation
+            # TODO(yby): specify `saved_tensors` for backward memory estimation
             node.meta['saved_tensor'] = []
             if phase == Phase.BACKWARD:
                 node.meta['saved_tensor'] = normalize_tuple(out)

             def wrap(x):
-                if isinstance(x, torch.Tensor):
-                    nonlocal fake_device
-                    if not x.is_meta:
-                        x = x.to(torch.device('meta'))
-                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
-
-            def set_node(x):
-                x._node = node
+                if isinstance(x, MetaTensor):
+                    x = FlopTensor(x)
+                    x._node = node
+                return x

             out = tree_map(wrap, out)
-            tree_map(set_node, out)
             return out

     def wrap(x):
-        fake_device = None
-        if isinstance(x, MetaTensor):
-            fake_device = x.device
-            x = x._tensor
-        detach(x)
-        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
-
-    # Basically, we need to detach the args and kwargs from the outer graph.
-    args = tree_map(wrap, args)
-    kwargs = tree_map(wrap, kwargs)
-
-    def set_placeholder(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, torch.Tensor):
+            x = FlopTensor(x)
+            if is_autogradable(x):
+                x.requires_grad_(True)
             x._node = subgraph.create_node('placeholder',
                                            'placeholder', (subgraph._root,),
                                            name=subgraph._graph_namespace.create_name('input', x._tensor))
             x._node.meta['phase'] = Phase.PLACEHOLDER
             x._node.meta['saved_tensor'] = []
+        detach(x)
+        return x

-    tree_map(set_placeholder, args)
-    tree_map(set_placeholder, kwargs)
+    # Basically, we need to detach the args and kwargs from the outer graph.
+    args = tree_map(wrap, args)
+    kwargs = tree_map(wrap, kwargs)

     def pack(x):
         global cache
         if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x._tensor]
+            x._node.meta['saved_tensor'] += [x]
             cache.add(x._tensor.data_ptr)
         return x
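Note the pack change: it now saves the FlopTensor x itself rather than the raw x._tensor, so saved activations keep their _node and fake-device metadata. Hooks of this pack/unpack shape are what torch.autograd.graph.saved_tensors_hooks (PyTorch >= 1.10) expects, and presumably how this file installs them; a minimal standalone demonstration:

import torch

def pack(x):
    # called whenever autograd saves a tensor for the backward pass
    print('saved for backward:', tuple(x.shape))
    return x

def unpack(x):
    return x

a = torch.rand(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    b = (a * a).sum()    # mul saves its operands, triggering pack
b.backward()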
@@ -191,16 +177,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     graph_info.fwd_mem_out = activation_size(out)

     def unwrap(x):
-        if isinstance(x, FlopTensor):
-            fake_device = x.device
-            x = x._tensor
-            detach(x)
-        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
+        return MetaTensor(x) if isinstance(x, torch.Tensor) else x

     return tree_map(unwrap, out), graph_info


-def profile_function(target: 'Target') -> Callable:
+def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.
@@ -222,7 +204,10 @@ def profile_function(target: 'Target') -> Callable:
         inplace = kwargs.get('inplace', False)
         if inplace:
             kwargs['inplace'] = False
-        out, meta = _profile(func, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             if target in [torch.nn.functional.relu]:
                 meta.save_fwd_in = False
@@ -234,7 +219,7 @@ def profile_function(target: 'Target') -> Callable:
     return f


-def profile_method(target: 'Target') -> Callable:
+def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_method` node
     record the memory cost and FLOPs of the execution.
@@ -243,13 +228,16 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(target, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(target, *args, **kwargs)
         return out, meta

     return f


-def profile_module(module: torch.nn.Module) -> Callable:
+def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.
@@ -271,7 +259,10 @@ def profile_module(module: torch.nn.Module) -> Callable:
         inplace = getattr(module, 'inplace', False)
         if inplace:
             module.inplace = False
-        out, meta = _profile(func, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             # super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
             # is the only inplace activation function that discard its input.
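profile_function, profile_method and profile_module all gain the same device switch: 'meta' (the default) routes to _profile_meta, anything else to the still-unimplemented _profile_concrete. A sketch of a call through the returned wrapper, reusing the assumed module path from the earlier example:

import torch
import torch.nn.functional as F
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.profiler.profiler import profile_function   # assumed module path

x = MetaTensor(torch.rand(2, 3), fake_device='cpu')
out, meta = profile_function(F.relu)(x)       # device='meta' by default
# profile_function(F.relu, device='cuda')(x)  # would reach _profile_concrete,
#                                             # which currently raises NotImplementedError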
@@ -1,5 +1,9 @@
+from copy import deepcopy
+from typing import Optional, Union, overload
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
+from torch.types import _bool, _dtype, _device
+from functools import singledispatchmethod

 __all__ = ['MetaTensor']

@@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):

     @staticmethod
     def __new__(cls, elem, fake_device=None):
+        # Avoid multiple wrapping
+        if isinstance(elem, MetaTensor):
+            fake_device = elem.device if fake_device is None else fake_device
+            elem = elem._tensor
+
         # The wrapping tensor (MetaTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before
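The new guard makes wrapping idempotent: handing an existing MetaTensor to the constructor unwraps its inner tensor and inherits its fake device unless one is passed explicitly. Expected behaviour, under the assumption that .device reports the advertised fake device as the surrounding comments state:

import torch
from colossalai.fx.profiler import MetaTensor

x = MetaTensor(torch.rand(4), fake_device='cuda:0')
y = MetaTensor(x)                     # no nested wrapper is created
assert y.device == x.device           # fake device 'cuda:0' is carried over
z = MetaTensor(x, fake_device='cpu')  # an explicit fake_device wins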
@@ -74,3 +83,32 @@
             return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

         return tree_map(wrap, out)
+
+    @singledispatchmethod
+    def to(self, *args, **kwargs) -> torch.Tensor:
+        """An extension of `torch.Tensor.to()` to MetaTensor
+
+        Returns:
+            result (MetaTensor): MetaTensor
+
+        Usage:
+            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
+            >>> tensor.to(torch.uint8)
+            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
+            >>> tensor.to(torch.device('cuda:42'))
+            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
+            >>> tensor.to('vulkan')
+            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
+        """
+        # this imitates c++ function in the way of @overload
+        return super().to(*args, **kwargs)
+
+    @to.register
+    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
+        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
+        return MetaTensor(deepcopy(result), fake_device=device)
+
+    @to.register
+    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
+        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
+        return MetaTensor(deepcopy(result), fake_device=device)
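This is the "fake tensor movement" of the commit title: to() dispatches on the type of its first positional argument via functools.singledispatchmethod (Python >= 3.8). The str and torch.device registrations move only the advertised fake device, leaving the storage on meta; any other first argument, such as a bare dtype, falls through to the generic super().to() path. Restating the docstring examples as a sketch:

import torch
from colossalai.fx.profiler import MetaTensor

t = MetaTensor(torch.rand(10), fake_device='cpu')
t.to(torch.device('cuda:42'))   # _device overload: fake device moves, data stays on meta
t.to('vulkan')                  # str overload: same mechanism via the string registration
t.to(torch.uint8)               # generic path: dtype conversion through super().to()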