mirror of https://github.com/hpcaitech/ColossalAI
[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver.
* [fx] remove import.
* [fx] remove import.
* [fx] remove import.
* [fx] tune the meta calculations.
* [fx] polish comments.
* [fx] remove assertions.
* [fx] modify test cases.
* [fx] modify test cases.
* [fx] optimize import.
* [fx
parent f7f2248771
commit d967779a32
@@ -175,6 +175,11 @@ def meta_hardswish(input: torch.Tensor):
    return torch.empty_like(input)


@register_meta(aten.hardtanh.default)
def meta_hardtanh(input: torch.Tensor, min, max):
    return torch.empty_like(input)
+
+
+@register_meta(aten.hardswish_backward.default)
+def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
+    grad_in = torch.empty_like(input)
@@ -189,7 +194,7 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:


@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
-    return torch.empty_like(input)
+    return input


@register_meta(aten.native_batch_norm.default)
@@ -211,13 +216,39 @@ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
    return dX, dgamma, dbeta


-@register_meta(aten.native_layer_norm.default)
-def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+@register_meta(aten.cudnn_batch_norm.default)
+def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
+    n_input = input.size(1)
+
+    output = torch.empty_like(input)
+    running_mean = torch.empty((n_input), device='meta')
+    running_var = torch.empty((n_input), device='meta')
+    reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+    return output, running_mean, running_var, reserve
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+# NB: CuDNN only implements the backward algorithm for batchnorm
+# in training mode (evaluation mode batchnorm has a different algorithm),
+# which is why this doesn't accept a 'training' parameter.
+@register_meta(aten.cudnn_batch_norm_backward.default)
+def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                           save_mean, save_invstd, eps, reserve):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(weight)
+    dbeta = torch.empty_like(weight)
+    return dX, dgamma, dbeta
+
+
+@register_meta(aten.native_layer_norm.default)
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    bs = input.size(0)
    n_input = input.size(1)

    output = torch.empty_like(input)
    running_mean = torch.empty((bs, n_input, 1), device='meta')
    running_var = torch.empty((bs, n_input, 1), device='meta')
    return output, running_mean, running_var
@@ -338,6 +369,23 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                            layout=grad_output.layout)


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    return torch.empty_like(condition)
+    result_type = torch.result_type(self, other)
+    return torch.empty_like(self, dtype=result_type)
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout.default)
+def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+    # notice that mask is bool
+    output = torch.empty_like(input)
+    mask = torch.empty_like(input, dtype=torch.bool)
+    return output, mask
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout_backward.default)
+def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+    return torch.empty_like(grad)
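None of these kernels computes anything; they only describe output metadata. For intuition, a minimal sketch of the mechanism they rely on, assuming PyTorch >= 1.12 with the `meta` backend available:

import torch

# Tensors on the `meta` device carry only shape/dtype/stride metadata, so
# dispatching an op on them does shape inference without any real allocation.
x = torch.empty(64, 3, 224, 224, device='meta')
y = torch.empty_like(x)    # the same pattern the kernels above return
z = x + x                  # routed to an aten meta kernel: shape math only
print(z.shape, z.device)   # torch.Size([64, 3, 224, 224]) meta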
@@ -2,6 +2,7 @@ from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler.tensor import MetaTensor
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:

    xbar = 0
    for n in node:
-        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
+        xbar += n.meta['fwd_mem_tmp']
+        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
+            xbar += n.meta['fwd_mem_out']
    return xbar
@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
    def _get_deps_size():
        deps_size = 0
        for k, v in deps.items():
            k: Node
            if v > 0:
                deps_size += k.meta['bwd_mem_out']
            if v == float('-inf'):
-                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
+                deps_size -= k.meta['fwd_mem_tmp']
+                if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
+                    deps_size -= k.meta['fwd_mem_out']

        return deps_size
@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
    """

    node_list = linearize(gm, cnode)
    mem_limit -= parameter_size(gm)
    mem_unit = mem_limit * (1.0 - eps) // mem_slots
+    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(data)

    chain: Chain = _construct_chain(node_list, data)
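A hedged usage sketch of the updated solver entry point; the import path, toy model, and memory budget are illustrative assumptions, not taken from this diff:

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.algorithms import solver_rotor  # import path assumed

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512))
gm = symbolic_trace(model)    # stand-in for a traced ColoGraphModule
data = torch.rand(128, 512)   # the solver now wraps this in a MetaTensor itself
# assume a 2 GiB budget; parameter_size(gm) is subtracted from it internally
gm = solver_rotor(gm, data, mem_limit=2 * 1024**3)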
@@ -94,11 +94,9 @@ class MetaInfoProp(torch.fx.Interpreter):
        tensor_meta = tree_map(extract_tensor_meta, result)
        n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
+        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
        # TODO: the attribute node_size should be removed in the future
        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
-        for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
        n.meta['type'] = type(result)

        # retain the autograd graph
@@ -224,7 +222,7 @@ class MetaInfoProp(torch.fx.Interpreter):
            result (Any): The argument value that was retrieved
            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
        """
-        return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
+        return args[0], GraphInfo(save_fwd_in=True)

    def propagate(self, *args):
        """
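The reworked propagation can be exercised end to end, mirroring the updated tests at the bottom of this diff:

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor

model = torch.nn.Linear(4, 16)
gm = symbolic_trace(model)
x = MetaTensor(torch.rand(2, 4, device='meta'), fake_device='cpu')
MetaInfoProp(gm).run(x)
for n in gm.graph.nodes:
    # every node now carries the GraphInfo fields plus the derived node_size
    print(n.name, n.meta.get('fwd_mem_tmp'), n.meta.get('fwd_mem_out'), n.meta.get('save_fwd_in'))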
@@ -0,0 +1,78 @@
import torch
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY

__all__ = []

if META_COMPATIBILITY:
    aten = torch.ops.aten

    ALIAS_ATEN = [
        # inplace reshaping
        aten.detach.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
    ]

    INPLACE_NEW = [
        aten.empty_like.default,
        aten.new_empty_strided.default,
    ]

    INPLACE_MATH_ATEN = [
        aten.add_.Tensor,
        aten.sub_.Tensor,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.bernoulli_.float,
    ]

    CLONE_ATEN = [
        aten.clone.default,
    ]

    __all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']

else:
    # TODO fill out the inplace ops
    INPLACE_OPS = [
        add,
        sub,
        mul,
        floordiv,
        neg,
        pos,
        getitem,
        setitem,
        getattr,
        torch.Tensor.cpu,
    ]

    # TODO: list all call_methods that are inplace here
    INPLACE_METHOD = [
        'transpose',
        'permute',
        # TODO: reshape may return a copy of the data if the data is not contiguous
        'reshape',
        'dim',
        'flatten',
        'size',
        'view',
        'unsqueeze',
        'to',
        'type',
        'flatten',
    ]

    # TODO: list all call_methods that are not inplace here
    NON_INPLACE_METHOD = [
        'chunk',
        'contiguous',
        'expand',
        'mean',
        'split',
    ]
    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
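For intuition, the `ALIAS_ATEN` entries are ops that return a view of their input rather than fresh storage, which is why the profiler treats them as allocating nothing; a quick check on the meta backend:

import torch

x = torch.empty(4, 8, device='meta')
y = x.t()                  # dispatches to aten.t.default, one of the ALIAS_ATEN entries
print(y.shape, y.device)   # torch.Size([8, 4]) meta -- a view, no new allocation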
@@ -3,9 +3,6 @@ from enum import Enum
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size, is_inplace
-from . import META_COMPATIBILITY
-if META_COMPATIBILITY:
-    from .memory import NORMALIZATION_ATEN, CLONE_ATEN


class Phase(Enum):
@@ -23,29 +20,32 @@ class GraphInfo:
    ============================================================================
                            -------------------------------
                            |            Node             |
-    [fwd_in] are ---------> | [fwd_in]      [bwd_out]    | <----- [bwd_out] is marks the memory for `grad_out`
+    [fwd_in] are ---------> | [fwd_in]      [bwd_out]    | <----- [bwd_out] marks the memory for `grad_out`.
    placeholders saved for  |    |  \__________    |      |
    backward.               |    |             \   |      |
                            | [fwd_tmp] ------> [bwd_tmp] | <----- [bwd_tmp] marks the peak memory
                            |    |  \_________     |      |        in backward pass.
                            |    /  \          \   |      |
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] | <-----
-    in [fwd_tmp] because    |    |          \_____  |    |
-    it is not saved for     |    |                \  |   |
-    backward.               -------------------------------
+    in [fwd_tmp] because    |    |  \_____          |    |
+    it is not saved for     |    |        \         |    |
+    backward.               | [fwd_out]    \        |    | <----- [fwd_out] is [fwd_in] for the next node.
+                            -------------------------------
    ============================================================================
    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node.
        bwd_flop (int): The backward FLOPs of a certain node.
-        fwd_mem_in (int): See the above illustration.
+        save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
        fwd_mem_tmp (int): See the above illustration.
        fwd_mem_out (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """
    fwd_flop: int = 0
    bwd_flop: int = 0
-    fwd_mem_in: int = 0
+    save_fwd_in: bool = False
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0
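Since `GraphInfo` is a dataclass, downstream bookkeeping reduces to plain field reads; a hedged illustration (the import path is assumed from this file's location, the numbers are made up):

from colossalai.fx.profiler.dataflow import GraphInfo  # module path assumed

info = GraphInfo(fwd_flop=1 << 20, bwd_flop=1 << 21, save_fwd_in=True,
                 fwd_mem_tmp=4096, fwd_mem_out=2048, bwd_mem_tmp=8192, bwd_mem_out=2048)
# under checkpointing, fwd_mem_tmp is recomputable and can be freed, while
# fwd_mem_out must stay resident only when some user saves it for backward
resident = info.fwd_mem_out if info.save_fwd_in else 0
print(resident)    # 2048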
@@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:


def is_saved(n: Node):
-    return n.meta.get('saved', False)
+    return len(n.meta['saved_tensor'])


def autograd_graph_analysis(graph: Graph) -> GraphInfo:
@@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    def _peak_memory(deps: Dict[Node, int]):
        peak_mem = 0
        for k, v in deps.items():
-            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
-                peak_mem += activation_size(k.meta['out'])
-            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
-                peak_mem -= activation_size(k.meta['out'])
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
+                peak_mem += activation_size(k.meta['saved_tensor'])
+            if v <= float('-inf') and is_phase(k, Phase.FORWARD):
+                peak_mem -= activation_size(k.meta['saved_tensor'])
        return peak_mem

    # deps is used to track all the memory dependencies of the graph.
@@ -99,25 +99,25 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:

    for n in graph.nodes:
        n: Node
-        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
-            # A forward tensor who is marked `save` but is not
-            # an input to `loss` should be saved during forward.
        deps[n] = len(n.users)
+        # A forward tensor that is marked `save` but is also
+        # an input to `Phase.FORWARD` should be saved during forward.
+        # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+        # Any `fwd_mem_in` should be kept in memory even if this function
+        # is checkpointed.
+        # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+        # the node, `fwd_mem_tmp` can be freed.
        if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.fwd_mem_in += activation_size(n.meta['out'])
+            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
        if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
        elif is_phase(n, Phase.BACKWARD):
            if len(n.users):
                graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
            else:
                # TODO: some of the bwd_mem_out might be model parameters.
                # basically a backward node without user is a `grad_out` node
-                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+                graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
        for input_n in n.all_input_nodes:
            if input_n in deps:
                deps[input_n] -= 1
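The `deps` map is a reference counter over the traced graph: each saved forward node starts at its user count and is decremented as backward nodes consume it. A standalone miniature of that liveness scheme (names and the freeing step are hypothetical; the diff only shows the `float('-inf')` sentinel being tested):

import math

deps = {'fwd_a': 2, 'fwd_b': 1}    # saved forward nodes -> unconsumed backward users

def consume(name):
    deps[name] -= 1
    if deps[name] <= 0:
        deps[name] = -math.inf     # freed; mirrors the float('-inf') sentinel above

consume('fwd_b')
print(deps)                        # {'fwd_a': 2, 'fwd_b': -inf}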
@@ -3,7 +3,8 @@ from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
-from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from ..memory import activation_size
+from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS

__all__ = ['profile_function', 'profile_module', 'profile_method']
@@ -1,88 +1,10 @@
import torch
from torch.fx import Node
from typing import Union, Dict, List, Tuple
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
-from . import META_COMPATIBILITY

__all__ = ['activation_size', 'parameter_size', 'is_inplace']

-if META_COMPATIBILITY:
-    aten = torch.ops.aten
-
-    WEIRD_OPS = [
-        torch.where,
-    ]
-
-    INPLACE_ATEN = [
-        aten.add_.Tensor,
-        aten.sub_.Tensor,
-        aten.div_.Tensor,
-        aten.div_.Scalar,
-        aten.mul_.Tensor,
-        aten.bernoulli_.float,
-
-        # inplace reshaping
-        aten.copy_.default,
-        aten.detach.default,
-        aten.t.default,
-        aten.transpose.int,
-        aten.view.default,
-        aten._unsafe_view.default,
-    ]
-
-    NORMALIZATION_ATEN = [
-        aten.native_batch_norm.default,
-        aten.native_layer_norm.default,
-        # aten.max_pool2d_with_indices.default,
-    ]
-
-    CLONE_ATEN = [
-        aten.clone.default,
-    ]
-
-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
-
-else:
-    # TODO fill out the inplace ops
-    INPLACE_OPS = [
-        add,
-        sub,
-        mul,
-        floordiv,
-        neg,
-        pos,
-        getitem,
-        setitem,
-        getattr,
-        torch.Tensor.cpu,
-    ]
-
-    # TODO: list all call_methods that are inplace here
-    INPLACE_METHOD = [
-        'transpose',
-        'permute',
-        # TODO: reshape may return a copy of the data if the data is not contiguous
-        'reshape',
-        'dim',
-        'flatten',
-        'size',
-        'view',
-        'unsqueeze',
-        'to',
-        'type',
-        'flatten',
-    ]
-
-    # TODO: list all call_methods that are not inplace here
-    NON_INPLACE_METHOD = [
-        'chunk',
-        'contiguous',
-        'expand',
-        'mean',
-        'split',
-    ]
-    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']


def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate activation size of a node.
@@ -106,13 +28,13 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:


def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate param size of a node.
+    """Calculate parameter size of a node.

    Args:
        mod (torch.nn.Module): The target `torch.nn.Module`

    Returns:
-        int: The param size
+        int: The parameter size
    """
    param_size = 0
    for param in mod.parameters():
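Both helpers flatten nested containers down to tensors; a small worked example (fp32, 4 bytes per element):

import torch
from colossalai.fx.profiler.memory import activation_size, parameter_size

out = (torch.empty(2, 4, device='meta'), {'logits': torch.empty(2, 10, device='meta')})
print(activation_size(out))                    # 2*4*4 + 2*10*4 = 112 bytes
print(parameter_size(torch.nn.Linear(4, 2)))   # (4*2 + 2) * 4 = 40 bytes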
@@ -132,7 +54,9 @@ def is_inplace(n: Node):
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
-        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
-            inplace = True
+        if META_COMPATIBILITY:
+            from .constant import ALIAS_ATEN
+            if n.target in ALIAS_ATEN:
+                inplace = True
    elif n.op == "call_module":
        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
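A hedged check of the `call_function` branch, which reads the `inplace` keyword off the traced node and needs no meta support:

import torch
from torch.fx import symbolic_trace
from colossalai.fx.profiler.memory import is_inplace

class M(torch.nn.Module):

    def forward(self, x):
        return torch.nn.functional.relu(x, inplace=True)

gm = symbolic_trace(M())
print([(n.name, is_inplace(n)) for n in gm.graph.nodes if n.op == 'call_function'])
# [('relu', True)] -- taken from the node's `inplace` kwarg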
@@ -1,7 +1,7 @@
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw

-from functools import reduce
+from functools import partial, reduce
import operator
from typing import Callable, List, Any
from numbers import Number
@@ -147,7 +147,8 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
    return norm_flop_jit


-def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
+    if training is None:
+        training = inputs[-3]
    assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
    if training:
@@ -201,6 +202,8 @@ flop_mapping = {
    # normalization
    aten.native_batch_norm.default: batchnorm_flop_jit,
    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
    aten.native_layer_norm.default: norm_flop_counter(2, 0),
    aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
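The `partial` binding is needed because the CuDNN backward overload carries no boolean flag at `inputs[-3]` to inspect (it only exists for training mode, per the comment in the meta kernels above). A self-contained miniature of the same dispatch pattern, with made-up FLOP math:

from functools import partial

def bn_flops(inputs, outputs, training=None):
    if training is None:
        training = inputs[-3]    # native_batch_norm carries the flag in its args
    return 2 * inputs[0] if training else inputs[0]    # rough per-element estimate

mapping = {
    'native_batch_norm': bn_flops,                                   # reads the flag itself
    'cudnn_batch_norm_backward': partial(bn_flops, training=True),   # flag pre-bound
}
print(mapping['cudnn_batch_norm_backward']([1024], None))    # 2048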
@@ -247,12 +250,14 @@ elementwise_flop_aten = [
    aten.hardswish.default,
    aten.hardswish_.default,
    aten.hardswish_backward.default,
    aten.hardtanh.default,
    aten.hardtanh_.default,
    aten.hardtanh_backward.default,
    aten.hardsigmoid_backward.default,
    aten.hardsigmoid.default,
    aten.gelu.default,
    aten.gelu_backward.default,
    aten.silu.default,
    aten.silu_.default,
    aten.silu_backward.default,
    aten.sigmoid.default,
@@ -264,6 +269,10 @@ elementwise_flop_aten = [
    aten.tanh.default,
    aten.tanh_backward.default,
    aten.threshold_backward.default,
+
+    # dropout
+    aten.native_dropout.default,
+    aten.native_dropout_backward.default,
]

for op in elementwise_flop_aten:
@@ -1,15 +1,21 @@
+from functools import partial
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
-from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS
+from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
+from .memory import activation_size
+from .constant import ALIAS_ATEN
from .tensor import MetaTensor
from .opcount import flop_mapping

__all__ = ['profile_function', 'profile_module', 'profile_method']

+# super-dainiu: this cache should be global, otherwise it cannot
+# track duplicated tensors between nodes
+cache = set()


def normalize_tuple(x):
    if not isinstance(x, tuple):
@@ -21,7 +27,17 @@ def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
+# super-dainiu:
+# x.detach() will change the unique identifier of data_ptr
+# we need to handle this in a stupid way
+def detach(x):
+    if isinstance(x, torch.Tensor):
+        requires_grad = x.requires_grad
+        x.requires_grad_(False)
+        x.requires_grad_(requires_grad)
+
+
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """
    Profile a Callable function with args and kwargs.
@@ -55,8 +71,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

        def __repr__(self):
            if self.grad_fn:
-                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
-            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
+                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
+            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -68,27 +84,47 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
            kwargs_node = tree_map(get_node, kwargs)
            node = subgraph.create_node('call_function', func, args_node, kwargs_node)

-            # do not allocate on `cpu`
+            # do not allocate on physical devices
            if 'device' in kwargs:
-                kwargs['device'] = 'meta'
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')

            def unwrap(x):
-                # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = FlopTensor(x.to('meta'))
-                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+                nonlocal fake_device
+                if isinstance(x, MetaTensor):
+                    fake_device = x.device
+                    x = x._tensor
+                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

-            # run aten for backend=CPU but actually on backend=Meta
+            # run aten for backend=WHATEVER but actually on backend=Meta
            out = func(*args, **kwargs)
            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
            node.meta['out'] = normalize_tuple(out)
            node.meta['phase'] = phase

+            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
+            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
+            # `Phase.FORWARD`
+            if phase == Phase.FORWARD:
+                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
+                    node.meta['phase'] = Phase.PLACEHOLDER
+
+            # TODO: specify `saved_tensors` for backward memory estimation
+            node.meta['saved_tensor'] = []
+            if phase == Phase.BACKWARD:
+                node.meta['saved_tensor'] = normalize_tuple(out)

            def wrap(x):
-                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

            def set_node(x):
                x._node = node
@@ -97,18 +133,13 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
            tree_map(set_node, out)
            return out

-    # `WEIRD_OPS` are tough to handle because they don't accept autograd
-    # on meta tensor.
-    if target not in WEIRD_OPS:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
-    else:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+    def wrap(x):
+        fake_device = None
+        if isinstance(x, MetaTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
@@ -120,14 +151,16 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
                'placeholder', (subgraph._root,),
                name=subgraph._graph_namespace.create_name('input', x._tensor))
            x._node.meta['phase'] = Phase.PLACEHOLDER
-            x._node.meta['out'] = (x._tensor,)
+            x._node.meta['saved_tensor'] = []

    tree_map(set_placeholder, args)
    tree_map(set_placeholder, kwargs)

    def pack(x):
-        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
-            x._node.meta['saved'] = True
+        global cache
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
+            x._node.meta['saved_tensor'] += [x._tensor]
+            cache.add(x._tensor.data_ptr)
        return x

    def unpack(x):
@@ -146,19 +179,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

    # If the output is not a floating point `torch.Tensor` or it does not
    # requires grad, then we should not run backward for this node.
-    if is_autogradable(out) and out.requires_grad:
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
            phase = Phase.BACKWARD
-        if isinstance(out, FlopTensor):
-            out._node.meta['save'] = False
-        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
-            out, device='meta')
-        torch.autograd.backward(out, FlopTensor(grad))
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
+    graph_info.fwd_mem_out = activation_size(out)

    def unwrap(x):
-        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+        if isinstance(x, FlopTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

    return tree_map(unwrap, out), graph_info
@@ -181,13 +218,15 @@ def profile_function(target: 'Target') -> Callable:
    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

        # If there is an argument that this `call_function` is inplace, we should
-        # skip the autograd profiling.
-        if kwargs.get('inplace', False):
-            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
-            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `target`
+        inplace = kwargs.get('inplace', False)
+        if inplace:
+            kwargs['inplace'] = False
+        out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            if target in [torch.nn.functional.relu]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
+        return out, meta

    f.__name__ = target.__name__
@@ -228,13 +267,17 @@ def profile_module(module: torch.nn.Module) -> Callable:
    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

        # If there is an argument that this `call_module` is inplace, we should
-        # skip the autograd profiling.
-        if getattr(module, 'inplace', False):
-            args = tree_map(lambda x: x.to('meta'), args)
-            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `module`.
+        inplace = getattr(module, 'inplace', False)
+        if inplace:
+            module.inplace = False
+        out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            # super-dainiu: experiments on mobilenet_v2 show that `torch.nn.ReLU`
+            # is the only inplace activation function that discards its input.
+            if type(module) in [torch.nn.ReLU]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
+        return out, meta

    f.__name__ = module.__class__.__name__
@@ -7,6 +7,7 @@ __all__ = ['MetaTensor']
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
+    `fake_device` is the device that `MetaTensor` is supposed to run on.
    """

    _tensor: torch.Tensor
@@ -14,7 +15,7 @@ class MetaTensor(torch.Tensor):
    __slots__ = ['_tensor']

    @staticmethod
-    def __new__(cls, elem):
+    def __new__(cls, elem, fake_device=None):
        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
@@ -25,24 +26,37 @@ class MetaTensor(torch.Tensor):
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
-            device='cpu',
+            device=fake_device if fake_device is not None else elem.device,
            requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
        r._tensor = elem
        # ...the real tensor is held as an element on the tensor.
+        if not r._tensor.is_meta:
+            r._tensor = r._tensor.to(torch.device('meta'))
+        # only tensor not on `meta` should be copied to `meta`
        return r

    def __repr__(self):
        if self.grad_fn:
-            return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
-        return f"MetaTensor({self._tensor})"
+            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
+        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        fake_device = None

        def unwrap(x):
-            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                x = MetaTensor(x)
-            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
+            nonlocal fake_device
+            if isinstance(x, MetaTensor):
+                fake_device = x.device
+                x = x._tensor
+            elif isinstance(x, torch.Tensor):
+                fake_device = x.device
+                x = x.to(torch.device('meta'))
+            return x
+
+        if 'device' in kwargs:
+            fake_device = kwargs['device']
+            kwargs['device'] = torch.device('meta')

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
@@ -53,6 +67,10 @@ class MetaTensor(torch.Tensor):
        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
-            return MetaTensor(x) if isinstance(x, torch.Tensor) else x
+            if isinstance(x, torch.Tensor):
+                nonlocal fake_device
+                if not x.is_meta:
+                    x = x.to(torch.device('meta'))
+            return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)
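In practice the wrapper lets a tensor that "looks like" it lives on a physical device execute entirely on the meta backend; a hedged example:

import torch
from colossalai.fx.profiler import MetaTensor

x = MetaTensor(torch.rand(4, 4), fake_device='cpu')
y = x @ x                   # goes through __torch_dispatch__ on meta storage
print(y.device)             # cpu -- the advertised fake device
print(y._tensor.is_meta)    # True -- the real payload never leaves `meta`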
@@ -1,10 +1,21 @@
-from colossalai.fx.profiler.memory import activation_size
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map


-def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+def is_autogradable(x):
+    return isinstance(x, torch.Tensor) and x.is_floating_point()
+
+
+def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
    """Trace forward and backward graph with MetaTensor

    Args:
@@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
        __slots__ = ['_tensor', '_node']

        @staticmethod
-        def __new__(cls, tensor, placeholder=False, name=None):
+        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
@@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
-                device='cpu',
+                device=fake_device if fake_device is not None else tensor.device,
                requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
@@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                    'placeholder', (graph._root,),
                    name=namespace.create_name(name, tensor))
            # ...the real tensor is held as an element on the tensor.
+            if not r._tensor.is_meta:
+                r._tensor = r._tensor.to(torch.device('meta'))
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

            def unwrap(x):
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = MetaProxy(x)
-                return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
+                nonlocal fake_device
+                if isinstance(x, MetaProxy):
+                    fake_device = x.device
+                    x = x._tensor
+                    # assert not isinstance(x, MetaProxy)
+                elif isinstance(x, torch.Tensor):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x

            def get_node(x):
                if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
@@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node('call_function', func, args_node, kwargs_node)

+            if 'device' in kwargs:
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')
+
            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)
@@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
-                return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return MetaProxy(
+                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

            def set_node(x):
                x._node = node
@@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
            return out

    def wrap(x):
-        return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
+        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

-    module(*args, **kwargs).sum().backward()
+    out = module(*args, **kwargs)
+
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor,
+                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
+                                    retain_graph=True)
    return graph
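Usage matches the new test file below; the fake device is now the second positional argument:

import torch
import torchvision.models as tm
from colossalai.fx import meta_trace

model = tm.resnet18()
data = torch.rand(8, 3, 224, 224, device='meta')
graph = meta_trace(model, torch.device('cpu'), data)    # forward + backward nodes
graph.print_tabular()    # torch.fx.Graph API; requires `tabulate`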
@@ -33,8 +33,9 @@ class MLP(torch.nn.Module):

@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
+    from colossalai.fx.profiler import MetaTensor
    model = MLP(MODEL_DIM)
-    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
+    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)
    annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
@@ -62,10 +62,7 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:

def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
    x.requires_grad = requires_backward
-    meta_x = MetaTensor(x.to('meta'))
-    if isinstance(f, nn.Module):
-        x_out, meta_out = f(x), f.to('meta')(meta_x)
-    else:
+    meta_x = MetaTensor(x)
    x_out, meta_out = f(x), f(meta_x)
    compare_all(x_out, meta_out)
    if requires_backward:
@@ -30,17 +30,17 @@ tmm_models = [
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_torchvision_models():
    for m in tm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_timm_models():
    for m in tmm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


if __name__ == '__main__':
@@ -0,0 +1,48 @@
import torchvision.models as tm
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest

if META_COMPATIBILITY:
    from colossalai.fx import meta_trace

tm_models = [
    tm.vgg11,
    tm.resnet18,
    tm.densenet121,
    tm.mobilenet_v3_small,
    tm.resnext50_32x4d,
    tm.wide_resnet50_2,
    tm.regnet_x_16gf,
    tm.mnasnet0_5,
    tm.efficientnet_b0,
]

tmm_models = [
    tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m,
    tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224,
    tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100,
    tmm.swin_transformer.swin_base_patch4_window7_224
]


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace():
    for m in tm_models:
        model = m()
        data = torch.rand(1000, 3, 224, 224, device='meta')
        graph = meta_trace(model, torch.device('cpu'), data)


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_timm_models_trace():
    for m in tmm_models:
        model = m()
        data = torch.rand(1000, 3, 224, 224, device='meta')
        graph = meta_trace(model, torch.device('cpu'), data)


if __name__ == '__main__':
    test_torchvision_models_trace()
    test_timm_models_trace()
@@ -1,12 +1,8 @@
import torch
-import torch.nn as nn
-import colossalai
-import colossalai.nn as col_nn
from torch.fx import symbolic_trace
+from colossalai import META_COMPATIBILITY
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
import pytest

BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16
@@ -22,6 +18,9 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
    model = torch.nn.Linear(DIM_IN, DIM_OUT)
    input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        input_sample = MetaTensor(input_sample, fake_device='cpu')
    orig_output = model(input_sample)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)