Mirror of https://github.com/hpcaitech/ColossalAI

[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.
* [fx] remove import.
* [fx] remove import.
* [fx] remove import.
* [fx] tune the meta calculations.
* [fx] polish comments.
* [fx] remove assertions.
* [fx] modify test cases.
* [fx] modify test cases.
* [fx] optimize import.
* [fx…

Branch: pull/1617/head^2
Parent: f7f2248771
Commit: d967779a32
@@ -175,6 +175,11 @@ def meta_hardswish(input: torch.Tensor):
     return torch.empty_like(input)


+@register_meta(aten.hardtanh.default)
+def meta_hardtanh(input: torch.Tensor, min, max):
+    return torch.empty_like(input)
+
+
 @register_meta(aten.hardswish_backward.default)
 def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
     grad_in = torch.empty_like(input)

@@ -189,7 +194,7 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
 @register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
-    return torch.empty_like(input)
+    return input


 @register_meta(aten.native_batch_norm.default)
@@ -211,13 +216,39 @@ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
     return dX, dgamma, dbeta


-@register_meta(aten.native_layer_norm.default)
-def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+@register_meta(aten.cudnn_batch_norm.default)
+def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)

     output = torch.empty_like(input)
     running_mean = torch.empty((n_input), device='meta')
     running_var = torch.empty((n_input), device='meta')
+    reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+    return output, running_mean, running_var, reserve
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+# NB: CuDNN only implements the backward algorithm for batchnorm
+# in training mode (evaluation mode batchnorm has a different algorithm),
+# which is why this doesn't accept a 'training' parameter.
+@register_meta(aten.cudnn_batch_norm_backward.default)
+def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                           save_mean, save_invstd, eps, reserve):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(weight)
+    dbeta = torch.empty_like(weight)
+    return dX, dgamma, dbeta
+
+
+@register_meta(aten.native_layer_norm.default)
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+    bs = input.size(0)
+    n_input = input.size(1)
+
+    output = torch.empty_like(input)
+    running_mean = torch.empty((bs, n_input, 1), device='meta')
+    running_var = torch.empty((bs, n_input, 1), device='meta')
     return output, running_mean, running_var
@@ -338,6 +369,23 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                                   layout=grad_output.layout)


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    return torch.empty_like(condition)
+    result_type = torch.result_type(self, other)
+    return torch.empty_like(self, dtype=result_type)
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout.default)
+def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+    # notice that mask is bool
+    output = torch.empty_like(input)
+    mask = torch.empty_like(input, dtype=torch.bool)
+    return output, mask
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout_backward.default)
+def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+    return torch.empty_like(grad)
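The meta implementations above only materialize shapes and dtypes on the `meta` device, so no real memory is allocated. As an illustrative sketch (not part of this commit), a registry in the same style could look like the following; `register_meta`, `meta_relu` and the `_meta_table` lookup are hypothetical names used only for the example:

import torch

aten = torch.ops.aten
_meta_table = {}    # hypothetical registry: aten overload -> shape-only implementation


def register_meta(op):
    # register a shape-only ("meta") implementation for an aten overload
    def wrapper(fn):
        _meta_table[op] = fn
        return fn
    return wrapper


@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
    # same shape/dtype as the input, but allocated on the meta device
    return torch.empty_like(input)


# usage: estimate the output of relu on a large tensor without allocating it
x = torch.empty(1024, 1024, device='meta')
out = _meta_table[aten.relu.default](x)
print(out.shape, out.device)    # torch.Size([1024, 1024]) meta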
@@ -2,6 +2,7 @@ from typing import List, Tuple
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler.tensor import MetaTensor
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function

@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
+        xbar += n.meta['fwd_mem_tmp']
+        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
+            xbar += n.meta['fwd_mem_out']
     return xbar

@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
     def _get_deps_size():
         deps_size = 0
         for k, v in deps.items():
+            k: Node
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
             if v == float('-inf'):
-                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
+                deps_size -= k.meta['fwd_mem_tmp']
+                if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
+                    deps_size -= k.meta['fwd_mem_out']

         return deps_size

@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
     """
     node_list = linearize(gm, cnode)
-    mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
+    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)

     chain: Chain = _construct_chain(node_list, data)
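The change in `_fwd_xbar` means a node's forward output is only charged to the checkpointing solver when at least one of its users actually saves that input for backward (`save_fwd_in`). A rough, standalone sketch of that accounting rule, using plain dictionaries in place of `torch.fx` node metadata (the field names mirror the ones in the diff, everything else is illustrative):

from typing import Dict, List


def fwd_xbar(nodes: List[Dict]) -> int:
    """Estimate the forward memory that stays alive for a segment of nodes.

    Each 'node' is a dict with the fields used in the diff:
      fwd_mem_tmp  - temporary forward memory of the node
      fwd_mem_out  - size of the node's output
      users        - downstream nodes, each carrying a 'save_fwd_in' flag
    """
    xbar = 0
    for n in nodes:
        xbar += n['fwd_mem_tmp']
        # only charge the output if some user keeps it as a saved input
        if any(u['save_fwd_in'] for u in n['users']):
            xbar += n['fwd_mem_out']
    return xbar


relu_like = {'fwd_mem_tmp': 0, 'fwd_mem_out': 4096, 'users': [{'save_fwd_in': False}]}
conv_like = {'fwd_mem_tmp': 1024, 'fwd_mem_out': 8192, 'users': [{'save_fwd_in': True}]}
print(fwd_xbar([relu_like, conv_like]))    # 1024 + 8192 = 9216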
@@ -94,11 +94,9 @@ class MetaInfoProp(torch.fx.Interpreter):
         tensor_meta = tree_map(extract_tensor_meta, result)
         n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
+        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
-        for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
         n.meta['type'] = type(result)

         # retain the autograd graph

@@ -224,7 +222,7 @@ class MetaInfoProp(torch.fx.Interpreter):
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
+        return args[0], GraphInfo(save_fwd_in=True)

     def propagate(self, *args):
         """
@@ -0,0 +1,78 @@
+import torch
+from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
+from . import META_COMPATIBILITY
+
+__all__ = []
+
+if META_COMPATIBILITY:
+    aten = torch.ops.aten
+
+    ALIAS_ATEN = [
+        # inplace reshaping
+        aten.detach.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten.view.default,
+        aten._unsafe_view.default,
+    ]
+
+    INPLACE_NEW = [
+        aten.empty_like.default,
+        aten.new_empty_strided.default,
+    ]
+
+    INPLACE_MATH_ATEN = [
+        aten.add_.Tensor,
+        aten.sub_.Tensor,
+        aten.div_.Tensor,
+        aten.div_.Scalar,
+        aten.mul_.Tensor,
+        aten.bernoulli_.float,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+
+else:
+    # TODO fill out the inplace ops
+    INPLACE_OPS = [
+        add,
+        sub,
+        mul,
+        floordiv,
+        neg,
+        pos,
+        getitem,
+        setitem,
+        getattr,
+        torch.Tensor.cpu,
+    ]
+
+    # TODO: list all call_methods that are inplace here
+    INPLACE_METHOD = [
+        'transpose',
+        'permute',
+        # TODO: reshape may return a copy of the data if the data is not contiguous
+        'reshape',
+        'dim',
+        'flatten',
+        'size',
+        'view',
+        'unsqueeze',
+        'to',
+        'type',
+        'flatten',
+    ]
+
+    # TODO: list all call_methods that are not inplace here
+    NON_INPLACE_METHOD = [
+        'chunk',
+        'contiguous',
+        'expand',
+        'mean',
+        'split',
+    ]
+    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
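`ALIAS_ATEN` groups the aten overloads whose outputs alias their inputs, so the profiler should not count them as fresh allocations. A hedged sketch of how such a table can be consulted when walking dispatched calls (the `is_aliasing_call` helper is made up for the example):

import torch

aten = torch.ops.aten

# overloads whose output shares storage with the input (mirrors ALIAS_ATEN above)
ALIAS_ATEN = [
    aten.detach.default,
    aten.t.default,
    aten.transpose.int,
    aten.view.default,
    aten._unsafe_view.default,
]


def is_aliasing_call(func) -> bool:
    # hypothetical helper: treat view-like overloads as zero-cost for memory estimation
    return func in ALIAS_ATEN


print(is_aliasing_call(aten.view.default))     # True
print(is_aliasing_call(aten.clone.default))    # False, clone allocates new memory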
@@ -3,9 +3,6 @@ from enum import Enum
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
-from . import META_COMPATIBILITY
-if META_COMPATIBILITY:
-    from .memory import NORMALIZATION_ATEN, CLONE_ATEN


 class Phase(Enum):

@@ -23,29 +20,32 @@ class GraphInfo:
    ============================================================================
                            -------------------------------
                            |            Node             |
-   [fwd_in] are   ---->    | [fwd_in]       [bwd_out]    |   <----- [bwd_out] is marks the memory for `grad_out`
+   [fwd_in] are   ---->    | [fwd_in]       [bwd_out]    |   <----- [bwd_out] is marks the memory for `grad_out`.
    placeholders saved for  |     | \__________     |     |
    backward.               |     |            \    |     |
                            | [fwd_tmp] ------> [bwd_tmp] |   <-----
                            |     |  \_________    |      |    [bwd_tmp] marks the peak memory
                            |     /  \          \  |      |    in backward pass.
-   [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |   <-----
-   in [fwd_tmp] because    |          |  \_____   |      |
-   it is not saved for     |          |       \   |      |
-   backward.               -------------------------------
+   [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |   <-----
+   in [fwd_tmp] because    |  \_____           |         |
+   it is not saved for     |        \          |         |
+   backward.               | [fwd_out]  \      |         |   <----- [fwd_out] is [fwd_in] for the next node.
+                           -------------------------------
    ============================================================================
    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node
        bwd_flop (int): The backward FLOPs of a certain node.
-       fwd_mem_in (int): See the above illustration.
+       save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
        fwd_mem_tmp (int): See the above illustration.
+       fwd_mem_out (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """
    fwd_flop: int = 0
    bwd_flop: int = 0
-   fwd_mem_in: int = 0
+   save_fwd_in: bool = False
    fwd_mem_tmp: int = 0
+   fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0
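After this change, `GraphInfo` carries a boolean `save_fwd_in` instead of the byte count `fwd_mem_in`, plus an explicit `fwd_mem_out`. A small sketch of how a consumer might use the dataclass; the field layout mirrors the diff, while the concrete values and the "parent/user" wiring are invented for illustration:

from dataclasses import dataclass


@dataclass
class GraphInfo:
    # mirrors the fields added/renamed in this commit
    fwd_flop: int = 0
    bwd_flop: int = 0
    save_fwd_in: bool = False
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0


# `save_fwd_in` says a node keeps its parents' outputs alive for backward,
# so a parent's `fwd_mem_out` is charged whenever any of its users sets the flag.
user = GraphInfo(fwd_flop=2_000_000, save_fwd_in=True, fwd_mem_tmp=4096, fwd_mem_out=8192)
parent = GraphInfo(fwd_mem_tmp=0, fwd_mem_out=2048)
users_of_parent = [user]

parent_cost = parent.fwd_mem_tmp + (parent.fwd_mem_out
                                    if any(u.save_fwd_in for u in users_of_parent) else 0)
print(parent_cost)    # 2048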
@@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:


 def is_saved(n: Node):
-    return n.meta.get('saved', False)
+    return len(n.meta['saved_tensor'])


 def autograd_graph_analysis(graph: Graph) -> GraphInfo:

@@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
-                peak_mem += activation_size(k.meta['out'])
-            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
-                peak_mem -= activation_size(k.meta['out'])
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
+                peak_mem += activation_size(k.meta['saved_tensor'])
+            if v <= float('-inf') and is_phase(k, Phase.FORWARD):
+                peak_mem -= activation_size(k.meta['saved_tensor'])
         return peak_mem

     # deps is used to track all the memory dependencies of the graph.

@@ -99,25 +99,25 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
-            deps[n] = len(n.users)
-        # A forward tensor who is marked `save` but is not
-        # an input to `loss` should be saved during forward.
+        deps[n] = len(n.users)
+        # A forward tensor who is marked `save` but is also
+        # an input to `Phase.FORWARD` should be saved during forward.
         # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
         # Any `fwd_mem_in` should be kept in memory even this function
         # is checkpointed.
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.fwd_mem_in += activation_size(n.meta['out'])
+            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
-                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+                graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
         for input_n in n.all_input_nodes:
             if input_n in deps:
                 deps[input_n] -= 1
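`autograd_graph_analysis` keeps a `deps` counter per node: it starts at the number of users and is decremented each time a consumer is visited, which lets peak backward memory be sampled while some saved tensors are still live. A condensed, standalone sketch of that bookkeeping idea; it is a simplification of the function above, with plain dicts standing in for `torch.fx` nodes:

def peak_backward_memory(nodes):
    """nodes: topologically ordered dicts with 'users', 'inputs', 'size', 'is_backward'."""
    deps = {}
    peak = 0
    live = 0
    for n in nodes:
        deps[id(n)] = len(n['users'])
        if n['is_backward']:
            live += n['size']              # freshly produced gradient/temporary
            peak = max(peak, live)
        for parent in n['inputs']:
            deps[id(parent)] -= 1
            if deps[id(parent)] == 0:      # last consumer seen: the tensor can be freed
                live -= parent['size']
    return peak


a = {'users': [], 'inputs': [], 'size': 100, 'is_backward': True}
b = {'users': [], 'inputs': [a], 'size': 50, 'is_backward': True}
a['users'].append(b)
print(peak_backward_memory([a, b]))    # 150: both tensors alive when b is produced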
@@ -3,7 +3,8 @@ from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx.node import Argument, Target
 from . import meta_profiler_function, meta_profiler_module
-from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from ..memory import activation_size
+from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS

 __all__ = ['profile_function', 'profile_module', 'profile_method']
@@ -1,88 +1,10 @@
 import torch
 from torch.fx import Node
 from typing import Union, Dict, List, Tuple
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY

 __all__ = ['activation_size', 'parameter_size', 'is_inplace']

-if META_COMPATIBILITY:
-    aten = torch.ops.aten
-
-    WEIRD_OPS = [
-        torch.where,
-    ]
-
-    INPLACE_ATEN = [
-        aten.add_.Tensor,
-        aten.sub_.Tensor,
-        aten.div_.Tensor,
-        aten.div_.Scalar,
-        aten.mul_.Tensor,
-        aten.bernoulli_.float,
-
-        # inplace reshaping
-        aten.copy_.default,
-        aten.detach.default,
-        aten.t.default,
-        aten.transpose.int,
-        aten.view.default,
-        aten._unsafe_view.default,
-    ]
-
-    NORMALIZATION_ATEN = [
-        aten.native_batch_norm.default,
-        aten.native_layer_norm.default,
-        # aten.max_pool2d_with_indices.default,
-    ]
-
-    CLONE_ATEN = [
-        aten.clone.default,
-    ]
-
-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
-
-else:
-    # TODO fill out the inplace ops
-    INPLACE_OPS = [
-        add,
-        sub,
-        mul,
-        floordiv,
-        neg,
-        pos,
-        getitem,
-        setitem,
-        getattr,
-        torch.Tensor.cpu,
-    ]
-
-    # TODO: list all call_methods that are inplace here
-    INPLACE_METHOD = [
-        'transpose',
-        'permute',
-        # TODO: reshape may return a copy of the data if the data is not contiguous
-        'reshape',
-        'dim',
-        'flatten',
-        'size',
-        'view',
-        'unsqueeze',
-        'to',
-        'type',
-        'flatten',
-    ]
-
-    # TODO: list all call_methods that are not inplace here
-    NON_INPLACE_METHOD = [
-        'chunk',
-        'contiguous',
-        'expand',
-        'mean',
-        'split',
-    ]
-    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
-

 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.

@@ -106,13 +28,13 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:


 def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate param size of a node.
+    """Calculate parameter size of a node.

     Args:
         mod (torch.nn.Module): The target `torch.nn.Module`

     Returns:
-        int: The param size
+        int: The parameter size
     """
     param_size = 0
     for param in mod.parameters():

@@ -132,8 +54,10 @@ def is_inplace(n: Node):
     inplace = False
     if n.op == "call_function":
         inplace = n.kwargs.get("inplace", False)
-        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
-            inplace = True
+        if META_COMPATIBILITY:
+            from .constant import ALIAS_ATEN
+            if n.target in ALIAS_ATEN:
+                inplace = True
     elif n.op == "call_module":
         inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
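`is_inplace` now consults the `ALIAS_ATEN` table from `constant.py` when meta-tensor tracing is available, so view-like aten calls are treated like explicitly in-place ops. A hedged sketch of the decision logic in isolation; the node here is a tiny stand-in object, not a real `torch.fx.Node`:

import torch

aten = torch.ops.aten
ALIAS_ATEN = [aten.detach.default, aten.t.default, aten.view.default]


class FakeNode:
    # minimal stand-in for torch.fx.Node, just for the example
    def __init__(self, op, target, kwargs=None):
        self.op, self.target, self.kwargs = op, target, kwargs or {}


def is_inplace(n: FakeNode) -> bool:
    if n.op == "call_function":
        # honour an explicit inplace=True kwarg, or an aliasing aten overload
        return bool(n.kwargs.get("inplace", False)) or n.target in ALIAS_ATEN
    return False


print(is_inplace(FakeNode("call_function", aten.view.default)))                                # True
print(is_inplace(FakeNode("call_function", torch.nn.functional.relu, {"inplace": True})))      # True
print(is_inplace(FakeNode("call_function", torch.add)))                                        # False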
@@ -1,7 +1,7 @@
 # adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
 # ideas from https://pastebin.com/AkvAyJBw

-from functools import reduce
+from functools import partial, reduce
 import operator
 from typing import Callable, List, Any
 from numbers import Number

@@ -147,8 +147,9 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
     return norm_flop_jit


-def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
-    training = inputs[-3]
+def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
+    if training is None:
+        training = inputs[-3]
     assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
     if training:
         return norm_flop_counter(1, 0)(inputs, outputs)    # pyre-ignore

@@ -201,6 +202,8 @@ flop_mapping = {
     # normalization
     aten.native_batch_norm.default: batchnorm_flop_jit,
     aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
     aten.native_layer_norm.default: norm_flop_counter(2, 0),
     aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
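`functools.partial` is used here to pre-bind `training=True` for the cuDNN batch-norm backward entry, since cuDNN only implements the training-mode backward and its argument list does not carry the flag at the position `inputs[-3]` expects. A tiny, generic illustration of the same pattern; the counter below is a made-up stand-in, not the real `batchnorm_flop_jit`:

from functools import partial
from typing import Any, List
from numbers import Number


def fake_bn_flops(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
    # stand-in counter: read the flag from the argument list unless it was pre-bound
    if training is None:
        training = inputs[-3]
    return 2 if training else 1


# pre-bind training=True, matching how the cudnn backward entry is registered above
cudnn_bwd_counter = partial(fake_bn_flops, training=True)
print(cudnn_bwd_counter(inputs=[None, None, None], outputs=[]))    # 2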
@@ -247,12 +250,14 @@ elementwise_flop_aten = [
     aten.hardswish.default,
     aten.hardswish_.default,
     aten.hardswish_backward.default,
+    aten.hardtanh.default,
     aten.hardtanh_.default,
     aten.hardtanh_backward.default,
     aten.hardsigmoid_backward.default,
     aten.hardsigmoid.default,
     aten.gelu.default,
     aten.gelu_backward.default,
+    aten.silu.default,
     aten.silu_.default,
     aten.silu_backward.default,
     aten.sigmoid.default,

@@ -264,6 +269,10 @@ elementwise_flop_aten = [
     aten.tanh.default,
     aten.tanh_backward.default,
     aten.threshold_backward.default,
+
+    # dropout
+    aten.native_dropout.default,
+    aten.native_dropout_backward.default,
 ]

 for op in elementwise_flop_aten:
@@ -1,15 +1,21 @@
+from functools import partial
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS
+from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
+from .memory import activation_size
+from .constant import ALIAS_ATEN
 from .tensor import MetaTensor
 from .opcount import flop_mapping

 __all__ = ['profile_function', 'profile_module', 'profile_method']

+# super-dainiu: this cache should be global, otherwise it cannot
+# track duplicated tensors between nodes
+cache = set()
+

 def normalize_tuple(x):
     if not isinstance(x, tuple):

@@ -21,7 +27,17 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
+# super-dainiu:
+# x.detach() will change the unique identifier of data_ptr
+# we need to handle this in a stupid way
+def detach(x):
+    if isinstance(x, torch.Tensor):
+        requires_grad = x.requires_grad
+        x.requires_grad_(False)
+        x.requires_grad_(requires_grad)
+
+
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
     Profile a Callable function with args and kwargs.

@@ -55,8 +71,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

         def __repr__(self):
             if self.grad_fn:
-                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
-            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
+                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
+            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -68,27 +84,47 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             kwargs_node = tree_map(get_node, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)

-            # do not allocate on `cpu`
+            # do not allocate on physical devices
             if 'device' in kwargs:
-                kwargs['device'] = 'meta'
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')

             def unwrap(x):
-                # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = FlopTensor(x.to('meta'))
-                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+                nonlocal fake_device
+                if isinstance(x, MetaTensor):
+                    fake_device = x.device
+                    x = x._tensor
+                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x

             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)

-            # run aten for backend=CPU but actually on backend=Meta
+            # run aten for backend=WHATEVER but actually on backend=Meta
             out = func(*args, **kwargs)
             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
-            node.meta['out'] = normalize_tuple(out)
             node.meta['phase'] = phase

+            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
+            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
+            # `Phase.FORWARD`
+            if phase == Phase.FORWARD:
+                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
+                    node.meta['phase'] = Phase.PLACEHOLDER
+
+            # TODO: specify `saved_tensors` for backward memory estimation
+            node.meta['saved_tensor'] = []
+            if phase == Phase.BACKWARD:
+                node.meta['saved_tensor'] = normalize_tuple(out)
+
             def wrap(x):
-                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

             def set_node(x):
                 x._node = node

@@ -97,18 +133,13 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             tree_map(set_node, out)
             return out

-    # `WEIRD_OPS` are tough to handle because they don't accept autograd
-    # on meta tensor.
-    if target not in WEIRD_OPS:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
-    else:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+    def wrap(x):
+        fake_device = None
+        if isinstance(x, MetaTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x

     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)

@@ -120,14 +151,16 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             'placeholder', (subgraph._root,),
             name=subgraph._graph_namespace.create_name('input', x._tensor))
         x._node.meta['phase'] = Phase.PLACEHOLDER
-        x._node.meta['out'] = (x._tensor,)
+        x._node.meta['saved_tensor'] = []

     tree_map(set_placeholder, args)
     tree_map(set_placeholder, kwargs)

     def pack(x):
-        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
-            x._node.meta['saved'] = True
+        global cache
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
+            x._node.meta['saved_tensor'] += [x._tensor]
+            cache.add(x._tensor.data_ptr)
         return x

     def unpack(x):
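The `pack` hook above records which tensors autograd saves for backward, deduplicating by `data_ptr` through a global cache so aliased tensors are only counted once. A hedged, standalone sketch of the same idea using PyTorch's public `saved_tensors_hooks` context manager; the model and sizes are invented for the example:

import torch

seen = set()
saved_bytes = 0


def pack(x):
    # record each distinct storage that autograd saves for backward
    global saved_bytes
    if isinstance(x, torch.Tensor) and x.data_ptr() not in seen:
        seen.add(x.data_ptr())
        saved_bytes += x.numel() * x.element_size()
    return x


def unpack(x):
    return x


lin = torch.nn.Linear(128, 128)
inp = torch.rand(32, 128, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = lin(inp).relu().sum()
print(f"activation memory saved for backward: {saved_bytes} bytes")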
@@ -146,19 +179,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
-    if is_autogradable(out) and out.requires_grad:
-        phase = Phase.BACKWARD
-        if isinstance(out, FlopTensor):
-            out._node.meta['save'] = False
-        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
-            out, device='meta')
-        torch.autograd.backward(out, FlopTensor(grad))
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            phase = Phase.BACKWARD
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)

     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
+    graph_info.fwd_mem_out = activation_size(out)

     def unwrap(x):
-        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+        if isinstance(x, FlopTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

     return tree_map(unwrap, out), graph_info
@@ -181,13 +218,15 @@ def profile_function(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

         # If there is an argument that this `call_function` is inplace, we should
-        # skip the autograd profiling.
-        if kwargs.get('inplace', False):
-            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
-            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `target`
+        inplace = kwargs.get('inplace', False)
+        if inplace:
+            kwargs['inplace'] = False
         out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            if target in [torch.nn.functional.relu]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
         return out, meta

     f.__name__ = target.__name__
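For in-place calls the wrapper now profiles the out-of-place variant and then zeroes out the fields an in-place op would not pay for. A hedged sketch of that post-processing step in isolation; the local `GraphInfo` and the numbers are illustrative only:

import torch
from dataclasses import dataclass


@dataclass
class GraphInfo:
    save_fwd_in: bool = False
    bwd_mem_out: int = 0


def adjust_for_inplace(target, meta: GraphInfo, inplace: bool) -> GraphInfo:
    # mirrors the adjustment above: an inplace call writes into its input, so it
    # produces no extra backward output buffer, and inplace ReLU can recover its
    # backward from the (overwritten) output, so the input need not be saved
    if inplace:
        if target in [torch.nn.functional.relu]:
            meta.save_fwd_in = False
        meta.bwd_mem_out = 0
    return meta


meta = GraphInfo(save_fwd_in=True, bwd_mem_out=4096)
print(adjust_for_inplace(torch.nn.functional.relu, meta, inplace=True))
# GraphInfo(save_fwd_in=False, bwd_mem_out=0)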
@@ -228,13 +267,17 @@ def profile_module(module: torch.nn.Module) -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

         # If there is an argument that this `call_module` is inplace, we should
-        # skip the autograd profiling.
-        if getattr(module, 'inplace', False):
-            args = tree_map(lambda x: x.to('meta'), args)
-            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `module`.
+        inplace = getattr(module, 'inplace', False)
+        if inplace:
+            module.inplace = False
         out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            # super-dainiu: experiments on mobilenet_v2 shows that `torch.nn.ReLU`
+            # is the only inplace activation function that discard its input.
+            if type(module) in [torch.nn.ReLU]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
         return out, meta

     f.__name__ = module.__class__.__name__
@@ -7,6 +7,7 @@ __all__ = ['MetaTensor']
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
+    `fake_device` is the device that `MetaTensor` is supposed to run on.
     """

     _tensor: torch.Tensor

@@ -14,7 +15,7 @@ class MetaTensor(torch.Tensor):
     __slots__ = ['_tensor']

     @staticmethod
-    def __new__(cls, elem):
+    def __new__(cls, elem, fake_device=None):
         # The wrapping tensor (MetaTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before

@@ -25,24 +26,37 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device='cpu',
+            device=fake_device if fake_device is not None else elem.device,
             requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
+        if not r._tensor.is_meta:
+            r._tensor = r._tensor.to(torch.device('meta'))
+            # only tensor not on `meta` should be copied to `meta`
         return r

     def __repr__(self):
         if self.grad_fn:
-            return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
-        return f"MetaTensor({self._tensor})"
+            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
+        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"

     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        fake_device = None

         def unwrap(x):
-            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                x = MetaTensor(x)
-            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
+            nonlocal fake_device
+            if isinstance(x, MetaTensor):
+                fake_device = x.device
+                x = x._tensor
+            elif isinstance(x, torch.Tensor):
+                fake_device = x.device
+                x = x.to(torch.device('meta'))
+            return x
+
+        if 'device' in kwargs:
+            fake_device = kwargs['device']
+            kwargs['device'] = torch.device('meta')

         args = tree_map(unwrap, args)
         kwargs = tree_map(unwrap, kwargs)

@@ -53,6 +67,10 @@ class MetaTensor(torch.Tensor):
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
-            return MetaTensor(x) if isinstance(x, torch.Tensor) else x
+            if isinstance(x, torch.Tensor):
+                nonlocal fake_device
+                if not x.is_meta:
+                    x = x.to(torch.device('meta'))
+            return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

         return tree_map(wrap, out)
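With `fake_device`, a `MetaTensor` keeps its payload on the `meta` device while advertising the device the model is meant to run on, so kernel selection behaves as if the tensor were real. A short usage sketch based on the constructor signature introduced here (shapes are arbitrary):

import torch
from colossalai.fx.profiler import MetaTensor

# payload lives on the meta device (no allocation), but the tensor reports 'cpu'
x = MetaTensor(torch.rand(8, 16, device='meta'), fake_device='cpu')
print(x.device)          # cpu
print(x._tensor.device)  # meta

# ops dispatched through MetaTensor stay shape-only
w = MetaTensor(torch.rand(16, 32, device='meta'), fake_device='cpu')
y = x @ w
print(tuple(y.shape))    # (8, 32)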
@@ -1,10 +1,21 @@
+from colossalai.fx.profiler.memory import activation_size
 import torch
 from torch.fx import Node, Graph
 from torch.fx.graph import _Namespace
 from torch.utils._pytree import tree_map


-def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+def is_autogradable(x):
+    return isinstance(x, torch.Tensor) and x.is_floating_point()
+
+
+def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
     """Trace forward and backward graph with MetaTensor

     Args:

@@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
         __slots__ = ['_tensor', '_node']

         @staticmethod
-        def __new__(cls, tensor, placeholder=False, name=None):
+        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
             r = torch.Tensor._make_wrapper_subclass(
                 cls,
                 tensor.size(),

@@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                 storage_offset=tensor.storage_offset(),
                 dtype=tensor.dtype,
                 layout=tensor.layout,
-                device='cpu',
+                device=fake_device if fake_device is not None else tensor.device,
                 requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
             r._tensor = tensor
             if placeholder:

@@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                     'placeholder', (graph._root,),
                     name=namespace.create_name(name, tensor))
             # ...the real tensor is held as an element on the tensor.
+            if not r._tensor.is_meta:
+                r._tensor = r._tensor.to(torch.device('meta'))
             return r

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

             def unwrap(x):
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = MetaProxy(x)
-                return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
+                nonlocal fake_device
+                if isinstance(x, MetaProxy):
+                    fake_device = x.device
+                    x = x._tensor
+                    # assert not isinstance(x, MetaProxy)
+                elif isinstance(x, torch.Tensor):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x

             def get_node(x):
                 if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):

@@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             kwargs_node = tree_map(get_node, kwargs)
             node = graph.create_node('call_function', func, args_node, kwargs_node)

+            if 'device' in kwargs:
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')
+
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)

@@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             # Now, we want to continue propagating this tensor, so we rewrap Tensors in
             # our custom tensor subclass
             def wrap(x):
-                return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return MetaProxy(
+                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x

             def set_node(x):
                 x._node = node

@@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             return out

     def wrap(x):
-        return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
+        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

     args = tree_map(wrap, args)
     kwargs = tree_map(wrap, kwargs)

-    module(*args, **kwargs).sum().backward()
+    out = module(*args, **kwargs)
+
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor,
+                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
+                                    retain_graph=True)
     return graph
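`meta_trace` now takes the fake device as its second positional argument, which is how the new tests call it. A brief usage sketch mirroring those tests; the model choice and batch size are arbitrary:

import torch
import torchvision.models as tm
from colossalai.fx import meta_trace

model = tm.resnet18()
data = torch.rand(8, 3, 224, 224, device='meta')

# trace forward and backward on the meta device, pretending to run on CPU
graph = meta_trace(model, torch.device('cpu'), data)
print(len(graph.nodes))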
@@ -33,8 +33,9 @@ class MLP(torch.nn.Module):

 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_comm_size_compute():
+    from colossalai.fx.profiler import MetaTensor
     model = MLP(MODEL_DIM)
-    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
+    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
     gm = symbolic_trace(model)
     MetaInfoProp(gm).run(input_sample)
     annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)

@@ -62,11 +62,8 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:

 def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
     x.requires_grad = requires_backward
-    meta_x = MetaTensor(x.to('meta'))
-    if isinstance(f, nn.Module):
-        x_out, meta_out = f(x), f.to('meta')(meta_x)
-    else:
-        x_out, meta_out = f(x), f(meta_x)
+    meta_x = MetaTensor(x)
+    x_out, meta_out = f(x), f(meta_x)
     compare_all(x_out, meta_out)
     if requires_backward:
         x_out.sum().backward()
@@ -30,17 +30,17 @@ tmm_models = [
 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_torchvision_models():
     for m in tm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_timm_models():
     for m in tmm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


 if __name__ == '__main__':
@@ -0,0 +1,48 @@
+import torchvision.models as tm
+import timm.models as tmm
+import torch
+from colossalai import META_COMPATIBILITY
+import pytest
+
+if META_COMPATIBILITY:
+    from colossalai.fx import meta_trace
+
+tm_models = [
+    tm.vgg11,
+    tm.resnet18,
+    tm.densenet121,
+    tm.mobilenet_v3_small,
+    tm.resnext50_32x4d,
+    tm.wide_resnet50_2,
+    tm.regnet_x_16gf,
+    tm.mnasnet0_5,
+    tm.efficientnet_b0,
+]
+
+tmm_models = [
+    tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m,
+    tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224,
+    tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100,
+    tmm.swin_transformer.swin_base_patch4_window7_224
+]
+
+
+@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+def test_torchvision_models_trace():
+    for m in tm_models:
+        model = m()
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        graph = meta_trace(model, torch.device('cpu'), data)
+
+
+@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+def test_timm_models_trace():
+    for m in tmm_models:
+        model = m()
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        graph = meta_trace(model, torch.device('cpu'), data)
+
+
+if __name__ == '__main__':
+    test_torchvision_models_trace()
+    test_timm_models_trace()
@@ -1,12 +1,8 @@
 import torch
-import torch.nn as nn
-import colossalai
-import colossalai.nn as col_nn
 from torch.fx import symbolic_trace
+from colossalai import META_COMPATIBILITY
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
-import pytest

 BATCH_SIZE = 2
 DIM_IN = 4
 DIM_OUT = 16

@@ -22,6 +18,9 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
 def test_meta_info_prop():
     model = torch.nn.Linear(DIM_IN, DIM_OUT)
     input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        input_sample = MetaTensor(input_sample, fake_device='cpu')
     orig_output = model(input_sample)
     gm = symbolic_trace(model)
     MetaInfoProp(gm).run(input_sample)