[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
pull/1617/head^2
Super Daniel 2022-09-23 10:59:47 +08:00 committed by GitHub
parent f7f2248771
commit d967779a32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 413 additions and 207 deletions

View File

@ -175,6 +175,11 @@ def meta_hardswish(input: torch.Tensor):
return torch.empty_like(input)
@register_meta(aten.hardtanh.default)
def meta_hardtanh(input: torch.Tensor, min, max):
return torch.empty_like(input)
@register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
grad_in = torch.empty_like(input)
@ -189,7 +194,7 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return torch.empty_like(input)
return input
@register_meta(aten.native_batch_norm.default)
@ -211,13 +216,39 @@ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
return dX, dgamma, dbeta
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
reserve = torch.empty((0), dtype=torch.uint8, device='meta')
return output, running_mean, running_var, reserve
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
return dX, dgamma, dbeta
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs = input.size(0)
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((bs, n_input, 1), device='meta')
running_var = torch.empty((bs, n_input, 1), device='meta')
return output, running_mean, running_var
@ -338,6 +369,23 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
layout=grad_output.layout)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
return torch.empty_like(condition)
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
output = torch.empty_like(input)
mask = torch.empty_like(input, dtype=torch.bool)
return output, mask
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return torch.empty_like(grad)

View File

@ -2,6 +2,7 @@ from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler.tensor import MetaTensor
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:
xbar = 0
for n in node:
xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
xbar += n.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
xbar += n.meta['fwd_mem_out']
return xbar
@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
def _get_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
deps_size -= k.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
deps_size -= k.meta['fwd_mem_out']
return deps_size
@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
"""
node_list = linearize(gm, cnode)
mem_limit -= parameter_size(gm)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data)

View File

@ -94,11 +94,9 @@ class MetaInfoProp(torch.fx.Interpreter):
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0} # extend MetaInfo to `n.meta`
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
n.meta['type'] = type(result)
# retain the autograd graph
@ -224,7 +222,7 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
return args[0], GraphInfo(save_fwd_in=True)
def propagate(self, *args):
"""

View File

@ -0,0 +1,78 @@
import torch
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY
__all__ = []
if META_COMPATIBILITY:
aten = torch.ops.aten
ALIAS_ATEN = [
# inplace reshaping
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
]
INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]
INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]
CLONE_ATEN = [
aten.clone.default,
]
__all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
else:
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
__all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']

View File

@ -3,9 +3,6 @@ from enum import Enum
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size, is_inplace
from . import META_COMPATIBILITY
if META_COMPATIBILITY:
from .memory import NORMALIZATION_ATEN, CLONE_ATEN
class Phase(Enum):
@ -23,29 +20,32 @@ class GraphInfo:
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out`
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out`.
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
in [fwd_tmp] because | | \_____ | |
it is not saved for | | \ | |
backward. | [fwd_out] \ | | <----- [fwd_out] is [fwd_in] for the next node.
-------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
save_fwd_in: bool = False
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0
@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
def is_saved(n: Node):
return n.meta.get('saved', False)
return len(n.meta['saved_tensor'])
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
peak_mem += activation_size(k.meta['out'])
if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
peak_mem -= activation_size(k.meta['out'])
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
peak_mem += activation_size(k.meta['saved_tensor'])
if v <= float('-inf') and is_phase(k, Phase.FORWARD):
peak_mem -= activation_size(k.meta['saved_tensor'])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
@ -99,25 +99,25 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
for n in graph.nodes:
n: Node
if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_mem_in += activation_size(n.meta['out'])
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
deps[n] = len(n.users)
# A forward tensor who is marked `save` but is also
# an input to `Phase.FORWARD` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['out'])
graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1

View File

@ -3,7 +3,8 @@ from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
from ..memory import activation_size
from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
__all__ = ['profile_function', 'profile_module', 'profile_method']

View File

@ -1,88 +1,10 @@
import torch
from torch.fx import Node
from typing import Union, Dict, List, Tuple
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
if META_COMPATIBILITY:
aten = torch.ops.aten
WEIRD_OPS = [
torch.where,
]
INPLACE_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
# inplace reshaping
aten.copy_.default,
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
]
NORMALIZATION_ATEN = [
aten.native_batch_norm.default,
aten.native_layer_norm.default,
# aten.max_pool2d_with_indices.default,
]
CLONE_ATEN = [
aten.clone.default,
]
__all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
else:
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
__all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
@ -106,13 +28,13 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate param size of a node.
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The param size
int: The parameter size
"""
param_size = 0
for param in mod.parameters():
@ -132,8 +54,10 @@ def is_inplace(n: Node):
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if META_COMPATIBILITY and n.target in INPLACE_ATEN:
inplace = True
if META_COMPATIBILITY:
from .constant import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)

View File

@ -1,7 +1,7 @@
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
from functools import reduce
from functools import partial, reduce
import operator
from typing import Callable, List, Any
from numbers import Number
@ -147,8 +147,9 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
return norm_flop_jit
def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
training = inputs[-3]
def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
if training is None:
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
@ -201,6 +202,8 @@ flop_mapping = {
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
@ -247,12 +250,14 @@ elementwise_flop_aten = [
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
aten.hardtanh.default,
aten.hardtanh_.default,
aten.hardtanh_backward.default,
aten.hardsigmoid_backward.default,
aten.hardsigmoid.default,
aten.gelu.default,
aten.gelu_backward.default,
aten.silu.default,
aten.silu_.default,
aten.silu_backward.default,
aten.sigmoid.default,
@ -264,6 +269,10 @@ elementwise_flop_aten = [
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
# dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
]
for op in elementwise_flop_aten:

View File

@ -1,15 +1,21 @@
from functools import partial
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
from .memory import WEIRD_OPS
from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
from .memory import activation_size
from .constant import ALIAS_ATEN
from .tensor import MetaTensor
from .opcount import flop_mapping
__all__ = ['profile_function', 'profile_module', 'profile_method']
# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
cache = set()
def normalize_tuple(x):
if not isinstance(x, tuple):
@ -21,7 +27,17 @@ def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
# super-dainiu:
# x.detach() will change the unique identifier of data_ptr
# we need to handle this in a stupid way
def detach(x):
if isinstance(x, torch.Tensor):
requires_grad = x.requires_grad
x.requires_grad_(False)
x.requires_grad_(requires_grad)
def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""
Profile a Callable function with args and kwargs.
@ -55,8 +71,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
def __repr__(self):
if self.grad_fn:
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@ -68,27 +84,47 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
kwargs_node = tree_map(get_node, kwargs)
node = subgraph.create_node('call_function', func, args_node, kwargs_node)
# do not allocate on `cpu`
# do not allocate on physical devices
if 'device' in kwargs:
kwargs['device'] = 'meta'
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
def unwrap(x):
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = FlopTensor(x.to('meta'))
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
nonlocal fake_device
if isinstance(x, MetaTensor):
fake_device = x.device
x = x._tensor
elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
# run aten for backend=WHATEVER but actually on backend=Meta
out = func(*args, **kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['out'] = normalize_tuple(out)
node.meta['phase'] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD`
if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
node.meta['phase'] = Phase.PLACEHOLDER
# TODO: specify `saved_tensors` for backward memory estimation
node.meta['saved_tensor'] = []
if phase == Phase.BACKWARD:
node.meta['saved_tensor'] = normalize_tuple(out)
def wrap(x):
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
def set_node(x):
x._node = node
@ -97,18 +133,13 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
tree_map(set_node, out)
return out
# `WEIRD_OPS` are tough to handle because they don't accept autograd
# on meta tensor.
if target not in WEIRD_OPS:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
else:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
def wrap(x):
fake_device = None
if isinstance(x, MetaTensor):
fake_device = x.device
x = x._tensor
detach(x)
return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
# Basically, we need to detach the args and kwargs from the outer graph.
args = tree_map(wrap, args)
@ -120,14 +151,16 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['phase'] = Phase.PLACEHOLDER
x._node.meta['out'] = (x._tensor,)
x._node.meta['saved_tensor'] = []
tree_map(set_placeholder, args)
tree_map(set_placeholder, kwargs)
def pack(x):
if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
x._node.meta['saved'] = True
global cache
if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
x._node.meta['saved_tensor'] += [x._tensor]
cache.add(x._tensor.data_ptr)
return x
def unpack(x):
@ -146,19 +179,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if is_autogradable(out) and out.requires_grad:
phase = Phase.BACKWARD
if isinstance(out, FlopTensor):
out._node.meta['save'] = False
grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
out, device='meta')
torch.autograd.backward(out, FlopTensor(grad))
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
phase = Phase.BACKWARD
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
graph_info.fwd_mem_out = activation_size(out)
def unwrap(x):
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
if isinstance(x, FlopTensor):
fake_device = x.device
x = x._tensor
detach(x)
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(unwrap, out), graph_info
@ -181,13 +218,15 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# If there is an argument that this `call_function` is inplace, we should
# skip the autograd profiling.
if kwargs.get('inplace', False):
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
# still run the profiling but discard some results regarding `target`
inplace = kwargs.get('inplace', False)
if inplace:
kwargs['inplace'] = False
out, meta = _profile(func, *args, **kwargs)
if inplace:
if target in [torch.nn.functional.relu]:
meta.save_fwd_in = False
meta.bwd_mem_out = 0
return out, meta
f.__name__ = target.__name__
@ -228,13 +267,17 @@ def profile_module(module: torch.nn.Module) -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# If there is an argument that this `call_module` is inplace, we should
# skip the autograd profiling.
if getattr(module, 'inplace', False):
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
# still run the profiling but discard some results regarding `module`.
inplace = getattr(module, 'inplace', False)
if inplace:
module.inplace = False
out, meta = _profile(func, *args, **kwargs)
if inplace:
# super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
# is the only inplace activation function that discard its input.
if type(module) in [torch.nn.ReLU]:
meta.save_fwd_in = False
meta.bwd_mem_out = 0
return out, meta
f.__name__ = module.__class__.__name__

View File

@ -7,6 +7,7 @@ __all__ = ['MetaTensor']
class MetaTensor(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
`fake_device` is the device that `MetaTensor` is supposed to run on.
"""
_tensor: torch.Tensor
@ -14,7 +15,7 @@ class MetaTensor(torch.Tensor):
__slots__ = ['_tensor']
@staticmethod
def __new__(cls, elem):
def __new__(cls, elem, fake_device=None):
# The wrapping tensor (MetaTensor) shouldn't hold any
# memory for the class in question, but it should still
# advertise the same device as before
@ -25,24 +26,37 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device='cpu',
device=fake_device if fake_device is not None else elem.device,
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
# only tensor not on `meta` should be copied to `meta`
return r
def __repr__(self):
if self.grad_fn:
return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
return f"MetaTensor({self._tensor})"
return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
fake_device = None
def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = MetaTensor(x)
return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
nonlocal fake_device
if isinstance(x, MetaTensor):
fake_device = x.device
x = x._tensor
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
@ -53,6 +67,10 @@ class MetaTensor(torch.Tensor):
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)

View File

@ -1,10 +1,21 @@
from colossalai.fx.profiler.memory import activation_size
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
"""Trace forward and backward graph with MetaTensor
Args:
@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
__slots__ = ['_tensor', '_node']
@staticmethod
def __new__(cls, tensor, placeholder=False, name=None):
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
r = torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
device='cpu',
device=fake_device if fake_device is not None else tensor.device,
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
'placeholder', (graph._root,),
name=namespace.create_name(name, tensor))
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = MetaProxy(x)
return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
nonlocal fake_device
if isinstance(x, MetaProxy):
fake_device = x.device
x = x._tensor
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node)
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return MetaProxy(
x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
def set_node(x):
x._node = node
@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
return out
def wrap(x):
return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
module(*args, **kwargs).sum().backward()
out = module(*args, **kwargs)
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
torch.autograd.backward(tensor,
MetaProxy(grad, fake_device=tensor.device, placeholder=True),
retain_graph=True)
return graph

View File

@ -33,8 +33,9 @@ class MLP(torch.nn.Module):
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
from colossalai.fx.profiler import MetaTensor
model = MLP(MODEL_DIM)
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)

View File

@ -62,11 +62,8 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
x.requires_grad = requires_backward
meta_x = MetaTensor(x.to('meta'))
if isinstance(f, nn.Module):
x_out, meta_out = f(x), f.to('meta')(meta_x)
else:
x_out, meta_out = f(x), f(meta_x)
meta_x = MetaTensor(x)
x_out, meta_out = f(x), f(meta_x)
compare_all(x_out, meta_out)
if requires_backward:
x_out.sum().backward()

View File

@ -30,17 +30,17 @@ tmm_models = [
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_torchvision_models():
for m in tm_models:
model = m().to('meta')
data = torch.rand(1000, 3, 224, 224, device='meta')
model(MetaTensor(data)).sum().backward()
model = m()
data = torch.rand(100000, 3, 224, 224, device='meta')
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_timm_models():
for m in tmm_models:
model = m().to('meta')
data = torch.rand(1000, 3, 224, 224, device='meta')
model(MetaTensor(data)).sum().backward()
model = m()
data = torch.rand(100000, 3, 224, 224, device='meta')
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
if __name__ == '__main__':

View File

@ -0,0 +1,48 @@
import torchvision.models as tm
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
if META_COMPATIBILITY:
from colossalai.fx import meta_trace
tm_models = [
tm.vgg11,
tm.resnet18,
tm.densenet121,
tm.mobilenet_v3_small,
tm.resnext50_32x4d,
tm.wide_resnet50_2,
tm.regnet_x_16gf,
tm.mnasnet0_5,
tm.efficientnet_b0,
]
tmm_models = [
tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m,
tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224,
tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100,
tmm.swin_transformer.swin_base_patch4_window7_224
]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace():
for m in tm_models:
model = m()
data = torch.rand(1000, 3, 224, 224, device='meta')
graph = meta_trace(model, torch.device('cpu'), data)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_timm_models_trace():
for m in tmm_models:
model = m()
data = torch.rand(1000, 3, 224, 224, device='meta')
graph = meta_trace(model, torch.device('cpu'), data)
if __name__ == '__main__':
test_torchvision_models_trace()
test_timm_models_trace()

View File

@ -1,12 +1,8 @@
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai import META_COMPATIBILITY
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
import pytest
BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16
@ -22,6 +18,9 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
model = torch.nn.Linear(DIM_IN, DIM_OUT)
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
input_sample = MetaTensor(input_sample, fake_device='cpu')
orig_output = model(input_sample)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)