[fx/profiler] assigned UUID to each unrecorded tensor / improved performance on GPT-2 (#1679)

* [fx/profiler] replace data_ptr with uuid for all tensors.

* [fx] modify uuid.

* [fx/profiler] tune performance on GPT-2.

* [fx] updates.

* [fx] debug.

* [fx] debug.

* [fx] cuda.
Super Daniel 2022-10-11 11:03:35 +08:00 committed by GitHub
parent 0df5034a36
commit 3dd6994427
13 changed files with 262 additions and 94 deletions
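The core of the change: every tensor the profiler touches is tagged with a persistent uuid, so duplicated tensors can be recognized across nodes even on the meta device, where data_ptr cannot serve as a stable identity. A minimal, self-contained sketch of the idea (illustrative only, not the code in this commit):

    import uuid
    import torch

    def tag(x: torch.Tensor) -> torch.Tensor:
        # attach a stable identifier, mirroring the set_uuid helper added in this commit
        if not hasattr(x, 'uuid'):
            x.uuid = uuid.uuid4()
        return x

    a = tag(torch.empty(4, device='meta'))
    b = a.view(2, 2)      # an aliasing view: no new storage is created
    b.uuid = a.uuid       # aliases inherit the uuid of their source, as the ALIAS_ATEN handling below does

    seen = set()
    for t in (a, b):
        if t.uuid not in seen:    # deduplicate by uuid instead of data_ptr
            seen.add(t.uuid)
    print(len(seen))              # 1 -> the aliased pair is counted once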


@@ -2,6 +2,7 @@ from typing import List, Set, Tuple
 import torch
 from torch.fx import GraphModule, Node
 import math
+from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
 
 __all__ = ['chen_greedy']
 CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
@@ -74,10 +75,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
         n: Node
-        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
+        temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += n.meta['fwd_mem_out']
+            x += calculate_fwd_in(n)
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1


@@ -1,8 +1,9 @@
 import sys
 from typing import List, Tuple
+from colossalai.fx.profiler.memory import calculate_fwd_in
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -124,9 +125,7 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp']
-        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
-            xbar += n.meta['fwd_mem_out']
+        xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
     return xbar
@@ -166,6 +165,21 @@
     return bwd_time
 
 
+def _get_fwd_mem_tmp(node: List[Node]) -> int:
+    """Get the forward temp memory of a node
+    This could be done by subtracting the saved activation from all output of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in linearized graph
+
+    Returns:
+        int: forward temp memory, unit Byte
+    """
+    n = node[-1]
+    return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
+
+
 def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node
@@ -184,9 +198,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         if v > 0:
             deps_size += k.meta['bwd_mem_out']
         if v == float('-inf'):
-            deps_size -= k.meta['fwd_mem_tmp']
-            if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
-                deps_size -= k.meta['fwd_mem_out']
+            deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
     return deps_size
@@ -212,15 +224,15 @@ def _construct_chain(node_list: List[List[Node]], input) -> Chain:
     bwd_time = []
     xbar_sizes = [activation_size(input)]
     x_sizes = [activation_size(input)]
-    # currently we can't get the temp memory needed in fwd
-    tmp_fwd = [0] * len(node_list)
+    tmp_fwd = []
     tmp_bwd = []
 
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(node[-1].meta['fwd_mem_out'])
+        x_sizes.append(calculate_fwd_out(node[-1]))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
+        tmp_fwd.append(_get_fwd_mem_tmp(node))
         tmp_bwd.append(_get_bwd_mem_tmp(node))
 
     bwd_time.append(0)
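The new `_get_fwd_mem_tmp` helper above estimates the forward temporary memory of a linearized stage as everything the stage outputs minus what is actually kept for backward. A toy numeric check of that subtraction (the numbers are made up for illustration):

    # bytes of all activations produced by the stage, i.e. activation_size(n.meta['fwd_out'])
    all_outputs_bytes = 3 * 1024
    # bytes of those outputs that downstream users really save, i.e. calculate_fwd_out(n)
    kept_for_backward_bytes = 1 * 1024

    fwd_mem_tmp = all_outputs_bytes - kept_for_backward_bytes
    print(fwd_mem_tmp)    # 2048 bytes can be released right after the forward pass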


@@ -1,12 +1,11 @@
 from dataclasses import asdict
-from colossalai.fx.profiler import GraphInfo
 import torch
 import torch.fx
 from torch.fx.node import Node, Argument, Target
 from torch.utils._pytree import tree_map
 from typing import Any, List, Tuple, NamedTuple, Dict
 from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
 
 @compatibility(is_backward_compatible=True)
@@ -62,12 +61,12 @@ class MetaInfoProp(torch.fx.Interpreter):
         # output of above code is
-        Op type      Op       Forward FLOPs    Backward FLOPs    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
-        -----------  -------  ---------------  ----------------  -------------  ---------  ---------  ---------  ---------
-        placeholder  input_1  0 FLOPs          0 FLOPs           False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
-        call_module  _0       128 FLOPs        288 FLOPs         True           0.12 KB    0.00 KB    0.34 KB    0.00 KB
-        call_module  _1       512 FLOPs        1,056 FLOPs       True           0.12 KB    0.00 KB    1.19 KB    0.00 KB
-        output       output   0 FLOPs          0 FLOPs           True           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        Op type      Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
+        -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
+        placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        call_module  _0       128 FLOPs        288 FLOPs         0.12 KB    0.00 KB    0.34 KB    0.00 KB
+        call_module  _1       512 FLOPs        1,056 FLOPs       0.12 KB    0.00 KB    1.19 KB    0.00 KB
+        output       output   0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
 
         Args:
             module (GraphModule): The module to be executed
@@ -102,7 +101,7 @@ class MetaInfoProp(torch.fx.Interpreter):
             n.meta['tensor_meta'] = tensor_meta
             n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
             # TODO: the attribute node_size should be removed in the future
-            setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
+            setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
             n.meta['type'] = type(result)
 
             # retain the autograd graph
@@ -228,6 +227,8 @@
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
+        if hasattr(args[0], '_tensor'):
+            return args[0], GraphInfo(fwd_in=[args[0]._tensor])
         return args[0], GraphInfo(save_fwd_in=True)
 
     def propagate(self, *args):
@@ -281,9 +282,9 @@
                 str(node),
                 flops_repr(node.meta['fwd_flop']),
                 flops_repr(node.meta['bwd_flop']),
-                node.meta['save_fwd_in'],
-                mem_repr(node.meta['fwd_mem_out']),
-                mem_repr(node.meta['fwd_mem_tmp']),
+                mem_repr(calculate_fwd_in(node)),
+                mem_repr(calculate_fwd_out(node)),
+                mem_repr(calculate_fwd_tmp(node)),
                 mem_repr(node.meta['bwd_mem_out']),
                 mem_repr(node.meta['bwd_mem_tmp']),
             ])
@@ -295,7 +296,7 @@
             'Op',
             'Forward FLOPs',
             'Backward FLOPs',
-            'SAVE_FWD_IN',
+            'FWD_IN',
             'FWD_OUT',
             'FWD_TMP',
             'BWD_OUT',


@@ -3,8 +3,9 @@ if META_COMPATIBILITY:
     from .opcount import flop_mapping
     from .tensor import MetaTensor
     from .profiler import profile_function, profile_method, profile_module
+    from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 else:
-    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
+    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 
 from .dataflow import GraphInfo
 from .memory import parameter_size, activation_size, is_inplace


@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 
 from .memory import activation_size, is_inplace
@@ -39,16 +39,25 @@ class GraphInfo:
         bwd_flop (int): The backward FLOPs of a certain node.
         bwd_time (float): The real backward time (s) of a certain node.
         save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+        fwd_in (List): See the above illustration.
+        fwd_tmp (List): See the above illustration.
+        fwd_out (List): See the above illustration.
         fwd_mem_tmp (int): See the above illustration.
         fwd_mem_out (int): See the above illustration.
         bwd_mem_tmp (int): See the above illustration.
         bwd_mem_out (int): See the above illustration.
     """
+    # TODO(super-dainiu): removed redundant items, currently all of them are necessary for development
     fwd_flop: int = 0
     fwd_time: float = 0.0
     bwd_flop: int = 0
     bwd_time: float = 0.0
     save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase
 
 
-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
@@ -113,9 +118,9 @@
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+            graph_info.fwd_in += n.meta['saved_tensor']
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+            graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
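With `fwd_in`, `fwd_tmp` and `fwd_out` now stored as lists of saved tensors rather than pre-computed byte counts, the sizes the solvers consume are derived later with `activation_size`. A small stand-alone illustration of that relationship (the dataclass here is a simplified stand-in, not the real `GraphInfo`):

    from dataclasses import dataclass, field
    from typing import List
    import torch

    def activation_size(tensors) -> int:
        # byte count of a tensor or a (possibly nested) collection of tensors
        if isinstance(tensors, torch.Tensor):
            return tensors.numel() * tensors.element_size()
        return sum(activation_size(t) for t in tensors)

    @dataclass
    class GraphInfoSketch:                      # hypothetical stand-in for GraphInfo
        fwd_in: List = field(default_factory=list)
        fwd_tmp: List = field(default_factory=list)
        fwd_out: List = field(default_factory=list)

    info = GraphInfoSketch(fwd_out=[torch.empty(2, 3, device='meta')])
    print(activation_size(info.fwd_out))        # 2 * 3 * 4 = 24 bytes of fp32 activations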


@@ -1,4 +1,5 @@
 from .registry import meta_profiler_function, meta_profiler_module
+from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 from .profiler_function import *
 from .profiler_module import *
 from .profiler import profile_function, profile_method, profile_module


@@ -0,0 +1,42 @@
+# for PyTorch 1.11 compatibility uses
+import torch
+from torch.fx import Node, GraphModule
+from typing import Union, Dict, List, Tuple
+
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
+
+
+def calculate_fwd_in(n: Node) -> bool:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        save_fwd_in (bool): the result of `save_fwd_in`
+    """
+    return n.meta['save_fwd_in']
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+    return n.meta["fwd_mem_tmp"]
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+    return n.meta['fwd_mem_out']


@@ -1,9 +1,11 @@
 import torch
-from torch.fx import Node
+from torch.fx import Node, GraphModule
 from typing import Union, Dict, List, Tuple
 from . import META_COMPATIBILITY
 
-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = [
+    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
+]
 
 
 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
@@ -21,7 +23,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     elif isinstance(out, dict):
         value_list = [v for _, v in out.items()]
         act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list):
+    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
         for element in out:
             act_size += activation_size(element)
     return act_size
@@ -42,6 +44,61 @@ def parameter_size(mod: torch.nn.Module) -> int:
     return param_size
 
 
+def calculate_fwd_in(n: Node) -> int:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_in (int): the result of `fwd_in`
+    """
+    return activation_size(n.meta["fwd_in"])
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+
+    def is_relu_node(n: Node) -> bool:
+        if n.op == 'call_function':
+            return n.target in [torch.nn.functional.relu]
+        elif n.op == 'call_module':
+            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+        return False
+
+    if not is_relu_node(n):
+        return activation_size(n.meta["fwd_tmp"])
+    return 0
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+
+    def intersect(a, b):
+        return {k: a[k] for k in a if k in b}
+
+    fwd_in = dict()
+    for u in n.users:
+        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(intersect(fwd_in, fwd_out))
+
+
 def is_inplace(n: Node):
     """Get the inplace argument from torch.fx.Node


@@ -226,6 +226,7 @@ flop_mapping = {
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding.default: elementwise_flop_counter(1, 0),
 }
 
 elementwise_flop_aten = [
@@ -304,10 +305,12 @@ zero_flop_aten = [
     aten.transpose.int,
     aten._to_copy.default,
     aten.unsqueeze.default,
+    aten.unbind.int,
     aten._unsafe_view.default,
     aten.view.default,
     aten.where.self,
     aten.zero_.default,
+    aten.zeros_like.default,
 ]
 
 for op in zero_flop_aten:


@@ -18,6 +18,9 @@ __all__ = ['profile_function', 'profile_module', 'profile_method']
 # track duplicated tensors between nodes
 cache = set()
 
+# a global identifier for inplace ops
+do_not_cache = False
+
 
 def normalize_tuple(x):
     if not isinstance(x, tuple):
@@ -223,10 +226,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     kwargs = tree_map(wrap, kwargs)
 
     def pack(x):
-        global cache
-        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x]
-            cache.add(x._tensor.data_ptr)
+        global cache, do_not_cache
+        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            x._node.meta['saved_tensor'] += [tensor]
+            if not do_not_cache:
+                cache.add(x._tensor.uuid)
         return x
 
     def unpack(x):
@@ -245,16 +251,25 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
-    for tensor in normalize_tuple(out):
-        if is_autogradable(tensor) and tensor.requires_grad:
-            phase = Phase.BACKWARD
-            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
-                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
-            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
+    if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
+        grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
+        phase = Phase.BACKWARD
+        torch.autograd.backward(
+            out,
+            grad_out,
+        )
 
     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
-    graph_info.fwd_mem_out = activation_size(out)
+
+    def extract_tensor(x: Any):
+        if isinstance(x, MetaTensor):
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            return tensor
+        return x
+
+    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
 
     def unwrap(x):
         return MetaTensor(x) if isinstance(x, torch.Tensor) else x
@@ -279,32 +294,39 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        # If there is an argument that this `call_function` is inplace, we should
-        # still run the profiling but discard some results regarding `target`
-        inplace = kwargs.get('inplace', False)
-        if inplace:
-            kwargs['inplace'] = False
-        if device == 'meta':
-            out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if target in [torch.nn.functional.relu]:
-                meta.save_fwd_in = False
-                meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
-        else:
-            out, meta = _profile_concrete(func, *args, **kwargs)
-
         # find the grad for parameter in args and kwargs
         param_size = 0
 
         def get_param_size(x):
+            nonlocal param_size
             if isinstance(x, Parameter):
                 param_size += activation_size(x)
 
         tree_map(get_param_size, args)
         tree_map(get_param_size, kwargs)
 
+        # If there is an argument that this `call_function` is inplace, we should
+        # still run the profiling but discard some results regarding `target`
+        global do_not_cache
+        inplace = kwargs.get('inplace', False)
+        if inplace or target in [torch.nn.functional.relu]:
+            do_not_cache = True
+            kwargs['inplace'] = False
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+            # currently we set the fwd_mem_tmp of ReLU to zero
+            if target in [torch.nn.functional.relu]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
+                meta.bwd_mem_out = 0
+                meta.fwd_mem_tmp = 0
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
+        if inplace:
+            kwargs['inplace'] = True
+        do_not_cache = False
+
         meta.bwd_mem_out -= param_size
         return out, meta
@@ -348,25 +370,30 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        # If there is an argument that this `call_module` is inplace, we should
-        # still run the profiling but discard some results regarding `module`.
-        inplace = getattr(module, 'inplace', False)
-
         # calculate parameter size
         param_size = parameter_size(module)
 
-        if inplace:
+        # If there is an argument that this `call_module` is inplace, we should
+        # still run the profiling but discard some results regarding `module`.
+        global do_not_cache
+        inplace = getattr(module, 'inplace', False)
+        if inplace or type(module) in [torch.nn.ReLU]:
+            do_not_cache = True
             module.inplace = False
         if device == 'meta':
            out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if type(module) in [torch.nn.modules.activation.ReLU]:
-                meta.save_fwd_in = False
+            # currently we set the fwd_tmp of ReLU to []
+            if type(module) in [torch.nn.ReLU]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
                 meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
+        if inplace:
+            module.inplace = True
+        do_not_cache = False
 
         # grad for param will not be counted
         meta.bwd_mem_out -= param_size
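The `pack` hook above records every tensor that autograd decides to save, deduplicated through the uuid cache, and the `do_not_cache` flag lets inplace ops (and ReLU) go through the same path without polluting the cache. A stripped-down sketch of the same hook pattern using only public PyTorch APIs (the `saved` list stands in for `n.meta['saved_tensor']`; names here are simplified, not the profiler's internals):

    import uuid
    import torch

    cache = set()           # uuids already attributed to some node
    do_not_cache = False    # flipped to True around inplace ops so their inputs are not cached
    saved = []              # stands in for n.meta['saved_tensor']

    def pack(t: torch.Tensor):
        key = getattr(t, 'uuid', None) or uuid.uuid4()    # fall back to a fresh id if untagged
        if key not in cache:
            saved.append(t.detach())
            if not do_not_cache:
                cache.add(key)
        return t

    def unpack(t):
        return t

    x = torch.randn(4, requires_grad=True)
    x.uuid = uuid.uuid4()
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        y = (x * x).sum()
    y.backward()
    print(len(saved))       # how many saved tensors were recorded (deduplicated by uuid when one is attached)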


@@ -1,13 +1,20 @@
 from copy import deepcopy
-from typing import Optional, Union, overload
+from typing import Optional
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
 from torch.types import _bool, _dtype, _device
-from functools import singledispatchmethod
+import uuid
+from .constant import ALIAS_ATEN
 
 __all__ = ['MetaTensor']
 
 
+def set_uuid(x):
+    if isinstance(x, torch.Tensor):
+        if not hasattr(x, 'uuid'):
+            setattr(x, 'uuid', uuid.uuid4())
+
+
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
@@ -42,6 +49,7 @@
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
+        set_uuid(r._tensor)
         return r
 
     def __repr__(self):
@@ -73,6 +81,11 @@
             # run aten for backend=CPU but actually on backend=Meta
             out = func(*args, **kwargs)
 
+        # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
+        # of the input
+        if func in ALIAS_ATEN:
+            setattr(out, 'uuid', args[0].uuid)
+
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
@@ -84,7 +97,6 @@
         return tree_map(wrap, out)
 
-    @singledispatchmethod
     def to(self, *args, **kwargs) -> torch.Tensor:
         """An extension of `torch.Tensor.to()` to MetaTensor
@@ -101,14 +113,13 @@
             MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
         """
         # this imitates c++ function in the way of @overload
-        return super().to(*args, **kwargs)
-
-    @to.register
-    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
-
-    @to.register
-    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
+        device = None
+        for arg in args:
+            if isinstance(arg, str) or isinstance(arg, _device):
+                device = arg
+        if 'device' in kwargs:
+            device = kwargs['device']
+        result = super().to(*args, **kwargs)
+        if device is not None:
+            result = MetaTensor(deepcopy(result), fake_device=device)
+        return result
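The rewritten `to()` no longer relies on `singledispatchmethod` overloads: it scans the positional and keyword arguments for an explicit device and re-wraps the result with that `fake_device`. The device-extraction part in isolation (a stand-alone sketch, not the class method itself):

    import torch

    def extract_device(*args, **kwargs):
        # mirror the scan in the rewritten MetaTensor.to(): the last positional str/torch.device wins,
        # and an explicit device= keyword overrides it
        device = None
        for arg in args:
            if isinstance(arg, (str, torch.device)):
                device = arg
        if 'device' in kwargs:
            device = kwargs['device']
        return device

    print(extract_device('cuda:0', non_blocking=True))      # cuda:0
    print(extract_device(torch.float16, device='cpu'))      # cpu
    print(extract_device(torch.float16))                    # None -> result is not re-wrapped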


@@ -13,6 +13,9 @@ from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -74,7 +77,7 @@ def _run_ckpt_solver(rank):
         m = model_cls(num_classes=5)
         graph = tracer.trace(root=m)
         gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-        MetaInfoProp(gm).run(data)
+        MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
         if solver == solver_rotor:
@@ -89,7 +92,6 @@
 
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
-@pytest.mark.skip('TODO: refactor ckpt solvers')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
@@ -111,7 +113,7 @@ def _run_ckpt_solver_torch11(rank):
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         if solver == solver_rotor:
-            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
         else:
            gm = solver(gm)
         assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
@@ -129,5 +131,5 @@
 
 if __name__ == '__main__':
     _run_ckpt_solver(rank=0)
-    # test_ckpt_solver()
-    # test_ckpt_solver_torch11()
+    test_ckpt_solver()
+    test_ckpt_solver_torch11()


@@ -1,3 +1,4 @@
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 import torch
 import torchvision.models as tm
 from colossalai.fx import ColoTracer
@@ -5,6 +6,9 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.algorithms import solver_rotor, linearize
 from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -15,7 +19,7 @@
     with_codegen = False
 
 
-@pytest.mark.skip(reason='TODO: modify calculations in rotor')
+@pytest.mark.skip(reason='TODO: modify the logger')
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 def test_linearize():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@@ -26,6 +30,7 @@
         graph = tracer.trace(model)
         graph.set_codegen(ActivationCheckpointCodeGen())
         gm = ColoGraphModule(model, graph, model.__class__.__name__)
+        MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu'))
         node_list = linearize(gm)
         gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
         op_list = gm.__sequence__.list_operations()