[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
Branch: pull/1617/head^2
Author: Super Daniel (committed by GitHub), 2022-09-23 10:59:47 +08:00
Parent: f7f2248771
Commit: d967779a32
16 changed files with 413 additions and 207 deletions


@@ -175,6 +175,11 @@ def meta_hardswish(input: torch.Tensor):
     return torch.empty_like(input)


+@register_meta(aten.hardtanh.default)
+def meta_hardtanh(input: torch.Tensor, min, max):
+    return torch.empty_like(input)
+
+
 @register_meta(aten.hardswish_backward.default)
 def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
     grad_in = torch.empty_like(input)

@@ -189,7 +194,7 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:

 @register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
-    return torch.empty_like(input)
+    return input


 @register_meta(aten.native_batch_norm.default)

@@ -211,13 +216,39 @@ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
     return dX, dgamma, dbeta


-@register_meta(aten.native_layer_norm.default)
-def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+@register_meta(aten.cudnn_batch_norm.default)
+def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)
     output = torch.empty_like(input)
     running_mean = torch.empty((n_input), device='meta')
     running_var = torch.empty((n_input), device='meta')
+    reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+    return output, running_mean, running_var, reserve
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+# NB: CuDNN only implements the backward algorithm for batchnorm
+# in training mode (evaluation mode batchnorm has a different algorithm),
+# which is why this doesn't accept a 'training' parameter.
+@register_meta(aten.cudnn_batch_norm_backward.default)
+def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                           save_mean, save_invstd, eps, reserve):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(weight)
+    dbeta = torch.empty_like(weight)
+    return dX, dgamma, dbeta
+
+
+@register_meta(aten.native_layer_norm.default)
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+    bs = input.size(0)
+    n_input = input.size(1)
+    output = torch.empty_like(input)
+    running_mean = torch.empty((bs, n_input, 1), device='meta')
+    running_var = torch.empty((bs, n_input, 1), device='meta')
     return output, running_mean, running_var

@@ -338,6 +369,23 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                                layout=grad_output.layout)


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    return torch.empty_like(condition)
+    result_type = torch.result_type(self, other)
+    return torch.empty_like(self, dtype=result_type)
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout.default)
+def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+    # notice that mask is bool
+    output = torch.empty_like(input)
+    mask = torch.empty_like(input, dtype=torch.bool)
+    return output, mask
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout_backward.default)
+def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+    return torch.empty_like(grad)
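For readers unfamiliar with the pattern in this file, here is a minimal, self-contained sketch of what a meta registration buys the profiler. The `register_meta`/`meta_registry` names below are local to this example (they are not the real ColossalAI internals); the point is only that each ATen overload is mapped to a function that fabricates outputs with the right shape and dtype on the meta device, so no real memory is ever touched.

import torch

aten = torch.ops.aten
meta_registry = {}    # hypothetical registry: ATen overload -> meta kernel


def register_meta(op):
    def wrapper(fn):
        meta_registry[op] = fn
        return fn
    return wrapper


@register_meta(aten.hardtanh.default)
def meta_hardtanh(input: torch.Tensor, min, max):
    # hardtanh is elementwise, so the output simply mirrors the input's metadata
    return torch.empty_like(input)


x = torch.empty(8, 1024, device='meta')                   # no storage is allocated
out = meta_registry[aten.hardtanh.default](x, -1.0, 1.0)
print(out.shape, out.device)                              # torch.Size([8, 1024]) meta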


@@ -2,6 +2,7 @@ from typing import List, Tuple
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler.tensor import MetaTensor
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function

@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
+        xbar += n.meta['fwd_mem_tmp']
+        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
+            xbar += n.meta['fwd_mem_out']
     return xbar

@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
     def _get_deps_size():
         deps_size = 0
         for k, v in deps.items():
+            k: Node
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
             if v == float('-inf'):
-                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
+                deps_size -= k.meta['fwd_mem_tmp']
+                if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
+                    deps_size -= k.meta['fwd_mem_out']

         return deps_size

@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
     """
     node_list = linearize(gm, cnode)
-    mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
+    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data)
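The accounting change above can be illustrated without the library: `fwd_mem_out` of a node is now charged only when some consumer actually marks the produced activation as saved for backward (`save_fwd_in`). A toy recomputation of the new `_fwd_xbar` rule, with plain dictionaries standing in for `torch.fx.Node` objects (illustrative only):

def fwd_xbar(nodes):
    xbar = 0
    for n in nodes:
        xbar += n['meta']['fwd_mem_tmp']
        # fwd_mem_out is counted only if at least one consumer saves this
        # activation for backward (save_fwd_in).
        if any(u['meta']['save_fwd_in'] for u in n['users']):
            xbar += n['meta']['fwd_mem_out']
    return xbar


relu = {'meta': {'fwd_mem_tmp': 0, 'fwd_mem_out': 4096, 'save_fwd_in': False}, 'users': []}
linear = {'meta': {'fwd_mem_tmp': 0, 'fwd_mem_out': 4096, 'save_fwd_in': True}, 'users': []}
relu['users'] = [linear]      # linear saves its input, so relu's output is counted
linear['users'] = []          # nothing downstream saves linear's output here
print(fwd_xbar([relu, linear]))   # 4096, not 8192 as the old rule would report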


@@ -94,11 +94,9 @@ class MetaInfoProp(torch.fx.Interpreter):
         tensor_meta = tree_map(extract_tensor_meta, result)
         n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
+        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
-        for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
         n.meta['type'] = type(result)

         # retain the autograd graph

@@ -224,7 +222,7 @@ class MetaInfoProp(torch.fx.Interpreter):
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
+        return args[0], GraphInfo(save_fwd_in=True)

     def propagate(self, *args):
         """


@@ -0,0 +1,78 @@
+import torch
+from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
+from . import META_COMPATIBILITY
+
+__all__ = []
+
+if META_COMPATIBILITY:
+    aten = torch.ops.aten
+
+    ALIAS_ATEN = [
+        # inplace reshaping
+        aten.detach.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten.view.default,
+        aten._unsafe_view.default,
+    ]
+
+    INPLACE_NEW = [
+        aten.empty_like.default,
+        aten.new_empty_strided.default,
+    ]
+
+    INPLACE_MATH_ATEN = [
+        aten.add_.Tensor,
+        aten.sub_.Tensor,
+        aten.div_.Tensor,
+        aten.div_.Scalar,
+        aten.mul_.Tensor,
+        aten.bernoulli_.float,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+
+else:
+    # TODO fill out the inplace ops
+    INPLACE_OPS = [
+        add,
+        sub,
+        mul,
+        floordiv,
+        neg,
+        pos,
+        getitem,
+        setitem,
+        getattr,
+        torch.Tensor.cpu,
+    ]
+
+    # TODO: list all call_methods that are inplace here
+    INPLACE_METHOD = [
+        'transpose',
+        'permute',
+        # TODO: reshape may return a copy of the data if the data is not contiguous
+        'reshape',
+        'dim',
+        'flatten',
+        'size',
+        'view',
+        'unsqueeze',
+        'to',
+        'type',
+        'flatten',
+    ]
+
+    # TODO: list all call_methods that are not inplace here
+    NON_INPLACE_METHOD = [
+        'chunk',
+        'contiguous',
+        'expand',
+        'mean',
+        'split',
+    ]
+
+    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
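A small, runnable demonstration of why the aliasing overloads above are singled out: they return views that share storage with their input, so the profiler should not charge extra activation memory for them, while `aten.clone.default` produces a real copy. This is an illustration, not library code.

import torch

x = torch.rand(4, 8)
v = x.view(8, 4)            # corresponds to aten.view.default in ALIAS_ATEN
t = x.t()                   # corresponds to aten.t.default
print(v.data_ptr() == x.data_ptr(), t.data_ptr() == x.data_ptr())   # True True

y = x.clone()               # aten.clone.default: new storage is allocated
print(y.data_ptr() == x.data_ptr())                                 # False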


@@ -3,9 +3,6 @@ from enum import Enum
 from typing import Dict

 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
-from . import META_COMPATIBILITY
-
-if META_COMPATIBILITY:
-    from .memory import NORMALIZATION_ATEN, CLONE_ATEN


 class Phase(Enum):

@@ -23,29 +20,32 @@ class GraphInfo:
    ============================================================================
                            -------------------------------
                            |            Node             |
-    [fwd_in] are    ---->  |   [fwd_in]     [bwd_out]    |   <-----   [bwd_out] is marks the memory for `grad_out`
+    [fwd_in] are    ---->  |   [fwd_in]     [bwd_out]    |   <-----   [bwd_out] is marks the memory for `grad_out`.
    placeholders saved for  |    |  \__________     |     |
    backward.               |    |             \    |     |
                            |   [fwd_tmp] ------> [bwd_tmp]   <-----
                            |    |  \_________     |      |   [bwd_tmp] marks the peak memory
                            |    /  \          \   |      |   in backward pass.
    [x] is not counted ---> |  [x]  [fwd_tmp] -> [bwd_tmp]|   <-----
-    in [fwd_tmp] because   |         |  \_____    |      |
-    it is not saved for    |         |        \   |      |
-    backward.              -------------------------------
+    in [fwd_tmp] because   |   |      \_____      |      |
+    it is not saved for    |   |            \     |      |
+    backward.              |  [fwd_out]      \    |      |   <-----   [fwd_out] is [fwd_in] for the next node.
+                           -------------------------------
    ============================================================================

    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node
        bwd_flop (int): The backward FLOPs of a certain node.
-        fwd_mem_in (int): See the above illustration.
+        save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
        fwd_mem_tmp (int): See the above illustration.
+        fwd_mem_out (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """
    fwd_flop: int = 0
    bwd_flop: int = 0
-    fwd_mem_in: int = 0
+    save_fwd_in: bool = False
    fwd_mem_tmp: int = 0
+    fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0

@@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:


 def is_saved(n: Node):
-    return n.meta.get('saved', False)
+    return len(n.meta['saved_tensor'])


 def autograd_graph_analysis(graph: Graph) -> GraphInfo:

@@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
-                peak_mem += activation_size(k.meta['out'])
-            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
-                peak_mem -= activation_size(k.meta['out'])
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
+                peak_mem += activation_size(k.meta['saved_tensor'])
+            if v <= float('-inf') and is_phase(k, Phase.FORWARD):
+                peak_mem -= activation_size(k.meta['saved_tensor'])
         return peak_mem

     # deps is used to track all the memory dependencies of the graph.

@@ -99,25 +99,25 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
+        deps[n] = len(n.users)

-        # A forward tensor who is marked `save` but is not
-        # an input to `loss` should be saved during forward.
+        # A forward tensor who is marked `save` but is also
+        # an input to `Phase.FORWARD` should be saved during forward.
         # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
         # Any `fwd_mem_in` should be kept in memory even this function
         # is checkpointed.
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.fwd_mem_in += activation_size(n.meta['out'])
+            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
-                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+                graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
         for input_n in n.all_input_nodes:
             if input_n in deps:
                 deps[input_n] -= 1
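A toy illustration (not the library code) of the dependency-count idea behind `_peak_memory`: a tensor produced in the backward phase stays alive until all of its users have executed, so peak backward memory is the running maximum of the sizes whose counts are still positive. Sizes and schedule below are made up.

sizes = {'grad_a': 4096, 'grad_b': 4096, 'grad_c': 2048}     # bytes, hypothetical
users_left = {'grad_a': 2, 'grad_b': 1, 'grad_c': 1}         # pending consumers

peak = live = 0
schedule = [                       # (produced, consumed) per backward step
    ('grad_a', []),
    ('grad_b', ['grad_a']),
    ('grad_c', ['grad_a', 'grad_b']),
]
for produced, consumed in schedule:
    live += sizes[produced]
    peak = max(peak, live)
    for t in consumed:
        users_left[t] -= 1
        if users_left[t] == 0:     # refcount hits zero -> memory can be freed
            live -= sizes[t]
print(peak)                        # 10240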


@@ -3,7 +3,8 @@ from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx.node import Argument, Target
 from . import meta_profiler_function, meta_profiler_module
-from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from ..memory import activation_size
+from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS

 __all__ = ['profile_function', 'profile_module', 'profile_method']


@@ -1,88 +1,10 @@
 import torch
 from torch.fx import Node
 from typing import Union, Dict, List, Tuple
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY

 __all__ = ['activation_size', 'parameter_size', 'is_inplace']

-if META_COMPATIBILITY:
-    aten = torch.ops.aten
-
-    WEIRD_OPS = [
-        torch.where,
-    ]
-
-    INPLACE_ATEN = [
-        aten.add_.Tensor,
-        aten.sub_.Tensor,
-        aten.div_.Tensor,
-        aten.div_.Scalar,
-        aten.mul_.Tensor,
-        aten.bernoulli_.float,
-
-        # inplace reshaping
-        aten.copy_.default,
-        aten.detach.default,
-        aten.t.default,
-        aten.transpose.int,
-        aten.view.default,
-        aten._unsafe_view.default,
-    ]
-
-    NORMALIZATION_ATEN = [
-        aten.native_batch_norm.default,
-        aten.native_layer_norm.default,
-        # aten.max_pool2d_with_indices.default,
-    ]
-
-    CLONE_ATEN = [
-        aten.clone.default,
-    ]
-
-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
-
-else:
-    # TODO fill out the inplace ops
-    INPLACE_OPS = [
-        add,
-        sub,
-        mul,
-        floordiv,
-        neg,
-        pos,
-        getitem,
-        setitem,
-        getattr,
-        torch.Tensor.cpu,
-    ]
-
-    # TODO: list all call_methods that are inplace here
-    INPLACE_METHOD = [
-        'transpose',
-        'permute',
-        # TODO: reshape may return a copy of the data if the data is not contiguous
-        'reshape',
-        'dim',
-        'flatten',
-        'size',
-        'view',
-        'unsqueeze',
-        'to',
-        'type',
-        'flatten',
-    ]
-
-    # TODO: list all call_methods that are not inplace here
-    NON_INPLACE_METHOD = [
-        'chunk',
-        'contiguous',
-        'expand',
-        'mean',
-        'split',
-    ]
-
-    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
-

 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.

@@ -106,13 +28,13 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:


 def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate param size of a node.
+    """Calculate parameter size of a node.

     Args:
         mod (torch.nn.Module): The target `torch.nn.Module`

     Returns:
-        int: The param size
+        int: The parameter size
     """
     param_size = 0
     for param in mod.parameters():

@@ -132,8 +54,10 @@ def is_inplace(n: Node):
     inplace = False
     if n.op == "call_function":
         inplace = n.kwargs.get("inplace", False)
-        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
-            inplace = True
+        if META_COMPATIBILITY:
+            from .constant import ALIAS_ATEN
+            if n.target in ALIAS_ATEN:
+                inplace = True
     elif n.op == "call_module":
         inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
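A hedged re-sketch of what `activation_size` computes (the real implementation lives in this file; this standalone version only shows the accounting: bytes = numel * element_size, summed over arbitrarily nested containers).

import torch
from typing import Any


def activation_size_sketch(out: Any) -> int:
    if isinstance(out, torch.Tensor):
        return out.numel() * out.element_size()
    if isinstance(out, dict):
        return sum(activation_size_sketch(v) for v in out.values())
    if isinstance(out, (list, tuple)):
        return sum(activation_size_sketch(v) for v in out)
    return 0


x = torch.empty(8, 1024, device='meta')                        # fp32: 4 bytes/elem
m = torch.empty(8, 1024, dtype=torch.bool, device='meta')      # bool: 1 byte/elem
print(activation_size_sketch((x, {'mask': m})))                # 32768 + 8192 = 40960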


@@ -1,7 +1,7 @@
 # adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
 # ideas from https://pastebin.com/AkvAyJBw
-from functools import reduce
+from functools import partial, reduce
 import operator
 from typing import Callable, List, Any
 from numbers import Number

@@ -147,8 +147,9 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
     return norm_flop_jit


-def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
-    training = inputs[-3]
+def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
+    if training is None:
+        training = inputs[-3]
     assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
     if training:
         return norm_flop_counter(1, 0)(inputs, outputs)    # pyre-ignore

@@ -201,6 +202,8 @@ flop_mapping = {
     # normalization
     aten.native_batch_norm.default: batchnorm_flop_jit,
     aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
     aten.native_layer_norm.default: norm_flop_counter(2, 0),
     aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),

@@ -247,12 +250,14 @@ elementwise_flop_aten = [
     aten.hardswish.default,
     aten.hardswish_.default,
     aten.hardswish_backward.default,
+    aten.hardtanh.default,
     aten.hardtanh_.default,
     aten.hardtanh_backward.default,
     aten.hardsigmoid_backward.default,
     aten.hardsigmoid.default,
     aten.gelu.default,
     aten.gelu_backward.default,
+    aten.silu.default,
     aten.silu_.default,
     aten.silu_backward.default,
     aten.sigmoid.default,

@@ -264,6 +269,10 @@ elementwise_flop_aten = [
     aten.tanh.default,
     aten.tanh_backward.default,
     aten.threshold_backward.default,
+
+    # dropout
+    aten.native_dropout.default,
+    aten.native_dropout_backward.default,
 ]

 for op in elementwise_flop_aten:
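An illustrative sketch of the fvcore-style convention behind `elementwise_flop_aten`: an elementwise op is charged (roughly) one FLOP per output element, so its counter only needs the output shapes. The function name below is local to this example and simplified relative to the real counter.

from functools import reduce
import operator
from typing import Any, List
from numbers import Number
import torch


def elementwise_flop_counter_sketch(inputs: List[Any], outputs: List[Any]) -> Number:
    return sum(reduce(operator.mul, out.shape, 1) for out in outputs if isinstance(out, torch.Tensor))


out = torch.empty(8, 64, 32, 32, device='meta')
print(elementwise_flop_counter_sketch([], [out]))    # 8 * 64 * 32 * 32 = 524288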


@@ -1,15 +1,21 @@
+from functools import partial
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS
+from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
+from .memory import activation_size
+from .constant import ALIAS_ATEN
 from .tensor import MetaTensor
 from .opcount import flop_mapping

 __all__ = ['profile_function', 'profile_module', 'profile_method']

+# super-dainiu: this cache should be global, otherwise it cannot
+# track duplicated tensors between nodes
+cache = set()
+

 def normalize_tuple(x):
     if not isinstance(x, tuple):

@@ -21,7 +27,17 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
+# super-dainiu:
+# x.detach() will change the unique identifier of data_ptr
+# we need to handle this in a stupid way
+def detach(x):
+    if isinstance(x, torch.Tensor):
+        requires_grad = x.requires_grad
+        x.requires_grad_(False)
+        x.requires_grad_(requires_grad)
+
+
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
     Profile a Callable function with args and kwargs.
@@ -55,8 +71,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

         def __repr__(self):
             if self.grad_fn:
-                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
-            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
+                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
+            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

@@ -68,27 +84,47 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             kwargs_node = tree_map(get_node, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)

-            # do not allocate on `cpu`
+            # do not allocate on physical devices
             if 'device' in kwargs:
-                kwargs['device'] = 'meta'
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')

             def unwrap(x):
-                # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = FlopTensor(x.to('meta'))
-                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+                nonlocal fake_device
+                if isinstance(x, MetaTensor):
+                    fake_device = x.device
+                    x = x._tensor
+                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x

             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)

-            # run aten for backend=CPU but actually on backend=Meta
+            # run aten for backend=WHATEVER but actually on backend=Meta
             out = func(*args, **kwargs)
             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
-            node.meta['out'] = normalize_tuple(out)
             node.meta['phase'] = phase

+            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
+            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
+            # `Phase.FORWARD`
+            if phase == Phase.FORWARD:
+                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
+                    node.meta['phase'] = Phase.PLACEHOLDER
+
+            # TODO: specify `saved_tensors` for backward memory estimation
+            node.meta['saved_tensor'] = []
+            if phase == Phase.BACKWARD:
+                node.meta['saved_tensor'] = normalize_tuple(out)

             def wrap(x):
-                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

             def set_node(x):
                 x._node = node
@@ -97,18 +133,13 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             tree_map(set_node, out)
             return out

-    # `WEIRD_OPS` are tough to handle because they don't accept autograd
-    # on meta tensor.
-    if target not in WEIRD_OPS:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
-    else:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+    def wrap(x):
+        fake_device = None
+        if isinstance(x, MetaTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x

     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)

@@ -120,14 +151,16 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
                 'placeholder', (subgraph._root,),
                 name=subgraph._graph_namespace.create_name('input', x._tensor))
             x._node.meta['phase'] = Phase.PLACEHOLDER
-            x._node.meta['out'] = (x._tensor,)
+            x._node.meta['saved_tensor'] = []

     tree_map(set_placeholder, args)
     tree_map(set_placeholder, kwargs)

     def pack(x):
-        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
-            x._node.meta['saved'] = True
+        global cache
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
+            x._node.meta['saved_tensor'] += [x._tensor]
+            cache.add(x._tensor.data_ptr)
         return x

     def unpack(x):

@@ -146,19 +179,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
-    if is_autogradable(out) and out.requires_grad:
-        phase = Phase.BACKWARD
-        if isinstance(out, FlopTensor):
-            out._node.meta['save'] = False
-        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
-            out, device='meta')
-        torch.autograd.backward(out, FlopTensor(grad))
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            phase = Phase.BACKWARD
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)

     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
+    graph_info.fwd_mem_out = activation_size(out)

     def unwrap(x):
-        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+        if isinstance(x, FlopTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

     return tree_map(unwrap, out), graph_info
@@ -181,13 +218,15 @@ def profile_function(target: 'Target') -> Callable:

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # If there is an argument that this `call_function` is inplace, we should
-        # skip the autograd profiling.
-        if kwargs.get('inplace', False):
-            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
-            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `target`
+        inplace = kwargs.get('inplace', False)
+        if inplace:
+            kwargs['inplace'] = False
         out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            if target in [torch.nn.functional.relu]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
         return out, meta

     f.__name__ = target.__name__

@@ -228,13 +267,17 @@ def profile_module(module: torch.nn.Module) -> Callable:

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # If there is an argument that this `call_module` is inplace, we should
-        # skip the autograd profiling.
-        if getattr(module, 'inplace', False):
-            args = tree_map(lambda x: x.to('meta'), args)
-            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `module`.
+        inplace = getattr(module, 'inplace', False)
+        if inplace:
+            module.inplace = False
         out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            # super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
+            # is the only inplace activation function that discard its input.
+            if type(module) in [torch.nn.ReLU]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
         return out, meta

     f.__name__ = module.__class__.__name__
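The pack/unpack hook mechanism the profiler relies on can be shown in isolation with real tensors: autograd calls `pack` for every tensor it saves for backward, so summing sizes there (with data_ptr-based deduplication, like the global `cache` above) approximates the saved-activation bookkeeping. Standalone sketch, requiring torch >= 1.10 for `torch.autograd.graph.saved_tensors_hooks`.

import torch

saved_bytes, seen = 0, set()


def pack(t: torch.Tensor):
    global saved_bytes
    if t.data_ptr() not in seen:          # dedupe aliased saves
        seen.add(t.data_ptr())
        saved_bytes += t.numel() * t.element_size()
    return t


def unpack(t: torch.Tensor):
    return t


x = torch.rand(64, 128, requires_grad=True)
w = torch.rand(128, 128, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.relu(x @ w)
y.sum().backward()
print(saved_bytes)    # bytes autograd kept alive between forward and backward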


@@ -7,6 +7,7 @@ __all__ = ['MetaTensor']
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
+    `fake_device` is the device that `MetaTensor` is supposed to run on.
     """

     _tensor: torch.Tensor

@@ -14,7 +15,7 @@ class MetaTensor(torch.Tensor):
     __slots__ = ['_tensor']

     @staticmethod
-    def __new__(cls, elem):
+    def __new__(cls, elem, fake_device=None):
         # The wrapping tensor (MetaTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before

@@ -25,24 +26,37 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device='cpu',
+            device=fake_device if fake_device is not None else elem.device,
             requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
+        if not r._tensor.is_meta:
+            r._tensor = r._tensor.to(torch.device('meta'))
+        # only tensor not on `meta` should be copied to `meta`
         return r

     def __repr__(self):
         if self.grad_fn:
-            return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
-        return f"MetaTensor({self._tensor})"
+            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
+        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"

     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        fake_device = None

         def unwrap(x):
-            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                x = MetaTensor(x)
-            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
+            nonlocal fake_device
+            if isinstance(x, MetaTensor):
+                fake_device = x.device
+                x = x._tensor
+            elif isinstance(x, torch.Tensor):
+                fake_device = x.device
+                x = x.to(torch.device('meta'))
+            return x
+
+        if 'device' in kwargs:
+            fake_device = kwargs['device']
+            kwargs['device'] = torch.device('meta')

         args = tree_map(unwrap, args)
         kwargs = tree_map(unwrap, kwargs)

@@ -53,6 +67,10 @@ class MetaTensor(torch.Tensor):
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
-            return MetaTensor(x) if isinstance(x, torch.Tensor) else x
+            if isinstance(x, torch.Tensor):
+                nonlocal fake_device
+                if not x.is_meta:
+                    x = x.to(torch.device('meta'))
+            return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

         return tree_map(wrap, out)
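A minimal demonstration of the meta-device trick that `MetaTensor` builds on: ops on `device='meta'` propagate shapes and dtypes but never allocate storage, so even a huge batch costs nothing to "run". What the wrapper adds on top, as described in the diff, is reporting a `fake_device` so code that branches on `.device` still sees the intended CPU/GPU device; the snippet below only shows the plain meta behaviour.

import torch

x = torch.empty(100000, 1024, device='meta')          # ~400 MB if it were real fp32
lin = torch.nn.Linear(1024, 4096, device='meta')      # parameters are meta too
y = lin(x)
print(y.shape, y.device, y.is_meta)                   # torch.Size([100000, 4096]) meta True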


@@ -1,10 +1,21 @@
+from colossalai.fx.profiler.memory import activation_size
 import torch
 from torch.fx import Node, Graph
 from torch.fx.graph import _Namespace
 from torch.utils._pytree import tree_map


-def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+def is_autogradable(x):
+    return isinstance(x, torch.Tensor) and x.is_floating_point()
+
+
+def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
     """Trace forward and backward graph with MetaTensor

     Args:

@@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
         __slots__ = ['_tensor', '_node']

         @staticmethod
-        def __new__(cls, tensor, placeholder=False, name=None):
+        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
             r = torch.Tensor._make_wrapper_subclass(
                 cls,
                 tensor.size(),

@@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                 storage_offset=tensor.storage_offset(),
                 dtype=tensor.dtype,
                 layout=tensor.layout,
-                device='cpu',
+                device=fake_device if fake_device is not None else tensor.device,
                 requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
             r._tensor = tensor
             if placeholder:

@@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                     'placeholder', (graph._root,),
                     name=namespace.create_name(name, tensor))
             # ...the real tensor is held as an element on the tensor.
+            if not r._tensor.is_meta:
+                r._tensor = r._tensor.to(torch.device('meta'))
             return r

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

             def unwrap(x):
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = MetaProxy(x)
-                return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
+                nonlocal fake_device
+                if isinstance(x, MetaProxy):
+                    fake_device = x.device
+                    x = x._tensor
+                    # assert not isinstance(x, MetaProxy)
+                elif isinstance(x, torch.Tensor):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x

             def get_node(x):
                 if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):

@@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             kwargs_node = tree_map(get_node, kwargs)
             node = graph.create_node('call_function', func, args_node, kwargs_node)

+            if 'device' in kwargs:
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')
+
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)

@@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             # Now, we want to continue propagating this tensor, so we rewrap Tensors in
             # our custom tensor subclass
             def wrap(x):
-                return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return MetaProxy(
+                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

             def set_node(x):
                 x._node = node

@@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             return out

     def wrap(x):
-        return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
+        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

     args = tree_map(wrap, args)
     kwargs = tree_map(wrap, kwargs)
-    module(*args, **kwargs).sum().backward()
+    out = module(*args, **kwargs)
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor,
+                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
+                                    retain_graph=True)
     return graph
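A standalone illustration (with real tensors, same idea) of the change at the end of the tracer: drive backward explicitly per output with a supplied gradient instead of `.sum().backward()`, which is what lets the tracer handle tuple outputs and non-floating outputs.

import torch


def normalize_tuple(x):
    return x if isinstance(x, tuple) else (x,)


model = torch.nn.Linear(4, 3)
out = (model(torch.rand(2, 4)), torch.arange(5))      # mixed float / integer outputs

for t in normalize_tuple(out):
    if isinstance(t, torch.Tensor) and t.is_floating_point() and t.requires_grad:
        torch.autograd.backward(t, torch.ones_like(t), retain_graph=True)

print(model.weight.grad.shape)    # torch.Size([3, 4])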


@@ -33,8 +33,9 @@ class MLP(torch.nn.Module):

 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_comm_size_compute():
+    from colossalai.fx.profiler import MetaTensor
     model = MLP(MODEL_DIM)
-    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
+    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
     gm = symbolic_trace(model)
     MetaInfoProp(gm).run(input_sample)
     annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)


@@ -62,11 +62,8 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:

 def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
     x.requires_grad = requires_backward
-    meta_x = MetaTensor(x.to('meta'))
-    if isinstance(f, nn.Module):
-        x_out, meta_out = f(x), f.to('meta')(meta_x)
-    else:
-        x_out, meta_out = f(x), f(meta_x)
+    meta_x = MetaTensor(x)
+    x_out, meta_out = f(x), f(meta_x)
     compare_all(x_out, meta_out)
     if requires_backward:
         x_out.sum().backward()


@@ -30,17 +30,17 @@ tmm_models = [

 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_torchvision_models():
     for m in tm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_timm_models():
     for m in tmm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


 if __name__ == '__main__':


@@ -0,0 +1,48 @@
+import torchvision.models as tm
+import timm.models as tmm
+import torch
+from colossalai import META_COMPATIBILITY
+import pytest
+
+if META_COMPATIBILITY:
+    from colossalai.fx import meta_trace
+
+tm_models = [
+    tm.vgg11,
+    tm.resnet18,
+    tm.densenet121,
+    tm.mobilenet_v3_small,
+    tm.resnext50_32x4d,
+    tm.wide_resnet50_2,
+    tm.regnet_x_16gf,
+    tm.mnasnet0_5,
+    tm.efficientnet_b0,
+]
+
+tmm_models = [
+    tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m,
+    tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224,
+    tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100,
+    tmm.swin_transformer.swin_base_patch4_window7_224
+]
+
+
+@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+def test_torchvision_models_trace():
+    for m in tm_models:
+        model = m()
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        graph = meta_trace(model, torch.device('cpu'), data)
+
+
+@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+def test_timm_models_trace():
+    for m in tmm_models:
+        model = m()
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        graph = meta_trace(model, torch.device('cpu'), data)
+
+
+if __name__ == '__main__':
+    test_torchvision_models_trace()
+    test_timm_models_trace()

@@ -1,12 +1,8 @@
 import torch
-import torch.nn as nn
-import colossalai
-import colossalai.nn as col_nn
 from torch.fx import symbolic_trace
+from colossalai import META_COMPATIBILITY
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
-import pytest

 BATCH_SIZE = 2
 DIM_IN = 4
 DIM_OUT = 16

@@ -22,6 +18,9 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
 def test_meta_info_prop():
     model = torch.nn.Linear(DIM_IN, DIM_OUT)
     input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        input_sample = MetaTensor(input_sample, fake_device='cpu')
     orig_output = model(input_sample)
     gm = symbolic_trace(model)
     MetaInfoProp(gm).run(input_sample)