mirror of https://github.com/hpcaitech/ColossalAI
[fx/profiler] assigned UUID to each unrecorded tensor / improved performance on GPT-2 (#1679)
* [fx/profiler] modify data_ptr into uuid for all tensors.
* [fx] modify uuid.
* [fx/profiler] tune performance on GPT-2.
* [fx] updates.
* [fx] debug.
* [fx] debug.
* [fx] cuda.

pull/1687/head
parent 0df5034a36
commit 3dd6994427
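The change described above replaces `data_ptr()`-based deduplication of saved activations with a persistent `uuid` attached to every tensor the profiler touches, since tensors on the `meta` device all report the same data pointer. The sketch below illustrates that bookkeeping idea only; it is not the library's actual API, and the helper names `tag_with_uuid`, `record_saved` and the `_seen` cache are hypothetical.

    import uuid
    import torch


    def tag_with_uuid(t: torch.Tensor) -> torch.Tensor:
        # Attach a uuid the first time this tensor is seen. Aliasing ops are
        # expected to propagate the same uuid, so two views of one storage
        # are only recorded once.
        if not hasattr(t, 'uuid'):
            t.uuid = uuid.uuid4()
        return t


    # Hypothetical cache keyed by uuid instead of data_ptr(), which cannot
    # distinguish tensors that live on the 'meta' device.
    _seen = set()


    def record_saved(node_meta: dict, t: torch.Tensor) -> None:
        # Save the activation on the owning node only if its uuid is new.
        tag_with_uuid(t)
        if t.uuid not in _seen:
            node_meta.setdefault('saved_tensor', []).append(t)
            _seen.add(t.uuid)

In the diff below, `set_uuid` in the `MetaTensor` wrapper and the reworked `pack` hook in the profiler play this role, with a `do_not_cache` flag as an escape hatch for inplace ops such as `torch.nn.ReLU`.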
@@ -2,6 +2,7 @@ from typing import List, Set, Tuple
 import torch
 from torch.fx import GraphModule, Node
 import math
+from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
 
 __all__ = ['chen_greedy']
 CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
@@ -74,10 +75,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
         n: Node
-        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
+        temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += n.meta['fwd_mem_out']
+            x += calculate_fwd_in(n)
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
@@ -1,8 +1,9 @@
 import sys
 from typing import List, Tuple
+from colossalai.fx.profiler.memory import calculate_fwd_in
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -124,9 +125,7 @@ def _fwd_xbar(node: List[Node]) -> int:
 
     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp']
-        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
-            xbar += n.meta['fwd_mem_out']
+        xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
     return xbar
 
 
@@ -166,6 +165,21 @@ def _bwd_time(node: List[Node]) -> int:
     return bwd_time
 
 
+def _get_fwd_mem_tmp(node: List[Node]) -> int:
+    """Get the forward temp memory of a node
+    This could be done by subtracting the saved activation from all output of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in linearized graph
+
+    Returns:
+        int: forward temp memory, unit Byte
+    """
+    n = node[-1]
+    return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
+
+
 def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node
 
@@ -184,9 +198,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         if v > 0:
             deps_size += k.meta['bwd_mem_out']
         if v == float('-inf'):
-            deps_size -= k.meta['fwd_mem_tmp']
-            if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
-                deps_size -= k.meta['fwd_mem_out']
+            deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
 
     return deps_size
 
@@ -212,15 +224,15 @@ def _construct_chain(node_list: List[List[Node]], input) -> Chain:
     bwd_time = []
     xbar_sizes = [activation_size(input)]
     x_sizes = [activation_size(input)]
-    # currently we can't get the temp memory needed in fwd
-    tmp_fwd = [0] * len(node_list)
+    tmp_fwd = []
     tmp_bwd = []
 
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(node[-1].meta['fwd_mem_out'])
+        x_sizes.append(calculate_fwd_out(node[-1]))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
+        tmp_fwd.append(_get_fwd_mem_tmp(node))
         tmp_bwd.append(_get_bwd_mem_tmp(node))
 
     bwd_time.append(0)
@@ -1,12 +1,11 @@
 from dataclasses import asdict
-from colossalai.fx.profiler import GraphInfo
 import torch
 import torch.fx
 from torch.fx.node import Node, Argument, Target
 from torch.utils._pytree import tree_map
 from typing import Any, List, Tuple, NamedTuple, Dict
 from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
 
 
 @compatibility(is_backward_compatible=True)
@@ -62,12 +61,12 @@ class MetaInfoProp(torch.fx.Interpreter):
 
 
         # output of above code is
-        Op type      Op       Forward FLOPs    Backward FLOPs    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
-        -----------  -------  ---------------  ----------------  -------------  ---------  ---------  ---------  ---------
-        placeholder  input_1  0 FLOPs          0 FLOPs           False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
-        call_module  _0       128 FLOPs        288 FLOPs         True           0.12 KB    0.00 KB    0.34 KB    0.00 KB
-        call_module  _1       512 FLOPs        1,056 FLOPs       True           0.12 KB    0.00 KB    1.19 KB    0.00 KB
-        output       output   0 FLOPs          0 FLOPs           True           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        Op type      Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
+        -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
+        placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        call_module  _0       128 FLOPs        288 FLOPs         0.12 KB    0.00 KB    0.34 KB    0.00 KB
+        call_module  _1       512 FLOPs        1,056 FLOPs       0.12 KB    0.00 KB    1.19 KB    0.00 KB
+        output       output   0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
 
     Args:
         module (GraphModule): The module to be executed
@@ -102,7 +101,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         n.meta['tensor_meta'] = tensor_meta
         n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
-        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
+        setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
         n.meta['type'] = type(result)
 
         # retain the autograd graph
@@ -228,6 +227,8 @@ class MetaInfoProp(torch.fx.Interpreter):
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
+        if hasattr(args[0], '_tensor'):
+            return args[0], GraphInfo(fwd_in=[args[0]._tensor])
         return args[0], GraphInfo(save_fwd_in=True)
 
     def propagate(self, *args):
@@ -281,9 +282,9 @@ class MetaInfoProp(torch.fx.Interpreter):
                 str(node),
                 flops_repr(node.meta['fwd_flop']),
                 flops_repr(node.meta['bwd_flop']),
-                node.meta['save_fwd_in'],
-                mem_repr(node.meta['fwd_mem_out']),
-                mem_repr(node.meta['fwd_mem_tmp']),
+                mem_repr(calculate_fwd_in(node)),
+                mem_repr(calculate_fwd_out(node)),
+                mem_repr(calculate_fwd_tmp(node)),
                 mem_repr(node.meta['bwd_mem_out']),
                 mem_repr(node.meta['bwd_mem_tmp']),
             ])
@@ -295,7 +296,7 @@ class MetaInfoProp(torch.fx.Interpreter):
             'Op',
             'Forward FLOPs',
             'Backward FLOPs',
-            'SAVE_FWD_IN',
+            'FWD_IN',
             'FWD_OUT',
             'FWD_TMP',
             'BWD_OUT',
@@ -3,8 +3,9 @@ if META_COMPATIBILITY:
     from .opcount import flop_mapping
     from .tensor import MetaTensor
     from .profiler import profile_function, profile_method, profile_module
+    from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 else:
-    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
+    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 
 from .dataflow import GraphInfo
 from .memory import parameter_size, activation_size, is_inplace
@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
 
@@ -39,16 +39,25 @@ class GraphInfo:
     bwd_flop (int): The backward FLOPs of a certain node.
     bwd_time (float): The real backward time (s) of a certain node.
     save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+    fwd_in (List): See the above illustration.
+    fwd_tmp (List): See the above illustration.
+    fwd_out (List): See the above illustration.
     fwd_mem_tmp (int): See the above illustration.
     fwd_mem_out (int): See the above illustration.
     bwd_mem_tmp (int): See the above illustration.
     bwd_mem_out (int): See the above illustration.
     """
+
+    # TODO(super-dainiu): removed redundant items, currently all of them are necessary for development
+
     fwd_flop: int = 0
     fwd_time: float = 0.0
     bwd_flop: int = 0
     bwd_time: float = 0.0
     save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase
 
 
-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
@@ -113,9 +118,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+            graph_info.fwd_in += n.meta['saved_tensor']
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+            graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
@@ -1,4 +1,5 @@
 from .registry import meta_profiler_function, meta_profiler_module
+from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 from .profiler_function import *
 from .profiler_module import *
 from .profiler import profile_function, profile_method, profile_module
@@ -0,0 +1,42 @@
+# for PyTorch 1.11 compatibility uses
+import torch
+from torch.fx import Node, GraphModule
+from typing import Union, Dict, List, Tuple
+
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
+
+
+def calculate_fwd_in(n: Node) -> bool:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        save_fwd_in (bool): the result of `save_fwd_in`
+    """
+    return n.meta['save_fwd_in']
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+    return n.meta["fwd_mem_tmp"]
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+    return n.meta['fwd_mem_out']
@@ -1,9 +1,11 @@
 import torch
-from torch.fx import Node
+from torch.fx import Node, GraphModule
 from typing import Union, Dict, List, Tuple
 from . import META_COMPATIBILITY
 
-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = [
+    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
+]
 
 
 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
@@ -21,7 +23,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     elif isinstance(out, dict):
         value_list = [v for _, v in out.items()]
         act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list):
+    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
         for element in out:
             act_size += activation_size(element)
     return act_size
@@ -42,6 +44,61 @@ def parameter_size(mod: torch.nn.Module) -> int:
     return param_size
 
 
+def calculate_fwd_in(n: Node) -> int:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_in (int): the result of `fwd_in`
+    """
+    return activation_size(n.meta["fwd_in"])
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+
+    def is_relu_node(n: Node) -> bool:
+        if n.op == 'call_function':
+            return n.target in [torch.nn.functional.relu]
+        elif n.op == 'call_module':
+            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+        return False
+
+    if not is_relu_node(n):
+        return activation_size(n.meta["fwd_tmp"])
+    return 0
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+
+    def intersect(a, b):
+        return {k: a[k] for k in a if k in b}
+
+    fwd_in = dict()
+    for u in n.users:
+        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(intersect(fwd_in, fwd_out))
+
+
 def is_inplace(n: Node):
     """Get the inplace argument from torch.fx.Node
 
@@ -226,6 +226,7 @@ flop_mapping = {
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding.default: elementwise_flop_counter(1, 0),
 }
 
 elementwise_flop_aten = [
@@ -304,10 +305,12 @@ zero_flop_aten = [
     aten.transpose.int,
     aten._to_copy.default,
     aten.unsqueeze.default,
+    aten.unbind.int,
     aten._unsafe_view.default,
     aten.view.default,
+    aten.where.self,
     aten.zero_.default,
     aten.zeros_like.default,
 ]
 
 for op in zero_flop_aten:
@@ -18,6 +18,9 @@ __all__ = ['profile_function', 'profile_module', 'profile_method']
 # track duplicated tensors between nodes
 cache = set()
 
+# a global identifier for inplace ops
+do_not_cache = False
+
 
 def normalize_tuple(x):
     if not isinstance(x, tuple):
@@ -223,10 +226,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     kwargs = tree_map(wrap, kwargs)
 
     def pack(x):
-        global cache
-        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x]
-            cache.add(x._tensor.data_ptr)
+        global cache, do_not_cache
+        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            x._node.meta['saved_tensor'] += [tensor]
+            if not do_not_cache:
+                cache.add(x._tensor.uuid)
         return x
 
     def unpack(x):
@@ -245,16 +251,25 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 
     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
-    for tensor in normalize_tuple(out):
-        if is_autogradable(tensor) and tensor.requires_grad:
-            phase = Phase.BACKWARD
-            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
-                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
-            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
+    if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
+        grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
+        phase = Phase.BACKWARD
+        torch.autograd.backward(
+            out,
+            grad_out,
+        )
 
     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
-    graph_info.fwd_mem_out = activation_size(out)
 
+    def extract_tensor(x: Any):
+        if isinstance(x, MetaTensor):
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            return tensor
+        return x
+
+    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
+
     def unwrap(x):
         return MetaTensor(x) if isinstance(x, torch.Tensor) else x
@@ -279,32 +294,39 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
 
-        # If there is an argument that this `call_function` is inplace, we should
-        # still run the profiling but discard some results regarding `target`
-        inplace = kwargs.get('inplace', False)
-        if inplace:
-            kwargs['inplace'] = False
-        if device == 'meta':
-            out, meta = _profile_meta(func, *args, **kwargs)
-
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if target in [torch.nn.functional.relu]:
-                meta.save_fwd_in = False
-                meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
-        else:
-            out, meta = _profile_concrete(func, *args, **kwargs)
-
         # find the grad for parameter in args and kwargs
         param_size = 0
 
         def get_param_size(x):
             nonlocal param_size
             if isinstance(x, Parameter):
                 param_size += activation_size(x)
 
         tree_map(get_param_size, args)
         tree_map(get_param_size, kwargs)
 
+        # If there is an argument that this `call_function` is inplace, we should
+        # still run the profiling but discard some results regarding `target`
+        global do_not_cache
+        inplace = kwargs.get('inplace', False)
+        if inplace or target in [torch.nn.functional.relu]:
+            do_not_cache = True
+            kwargs['inplace'] = False
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+            # currently we set the fwd_mem_tmp of ReLU to zero
+            if target in [torch.nn.functional.relu]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
+                meta.bwd_mem_out = 0
+                meta.fwd_mem_tmp = 0
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
+
+        if inplace:
+            kwargs['inplace'] = True
+            do_not_cache = False
+
         meta.bwd_mem_out -= param_size
         return out, meta
@@ -348,25 +370,30 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
 
-        # If there is an argument that this `call_module` is inplace, we should
-        # still run the profiling but discard some results regarding `module`.
-        inplace = getattr(module, 'inplace', False)
-
         # calculate parameter size
         param_size = parameter_size(module)
 
-        if inplace:
+        # If there is an argument that this `call_module` is inplace, we should
+        # still run the profiling but discard some results regarding `module`.
+        global do_not_cache
+
+        inplace = getattr(module, 'inplace', False)
+        if inplace or type(module) in [torch.nn.ReLU]:
+            do_not_cache = True
             module.inplace = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if type(module) in [torch.nn.modules.activation.ReLU]:
-                meta.save_fwd_in = False
+            # currently we set the fwd_tmp of ReLU to []
+            if type(module) in [torch.nn.ReLU]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
                 meta.bwd_mem_out = 0
                 meta.fwd_mem_tmp = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
 
             module.inplace = True
+            do_not_cache = False
 
         # grad for param will not be counted
         meta.bwd_mem_out -= param_size
@@ -1,13 +1,20 @@
 from copy import deepcopy
-from typing import Optional, Union, overload
+from typing import Optional
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
 from torch.types import _bool, _dtype, _device
-from functools import singledispatchmethod
+import uuid
+from .constant import ALIAS_ATEN
 
 __all__ = ['MetaTensor']
 
 
+def set_uuid(x):
+    if isinstance(x, torch.Tensor):
+        if not hasattr(x, 'uuid'):
+            setattr(x, 'uuid', uuid.uuid4())
+
+
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
@@ -42,6 +49,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
+        set_uuid(r._tensor)
         return r
 
     def __repr__(self):
@@ -73,6 +81,11 @@ class MetaTensor(torch.Tensor):
         # run aten for backend=CPU but actually on backend=Meta
         out = func(*args, **kwargs)
 
+        # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
+        # of the input
+        if func in ALIAS_ATEN:
+            setattr(out, 'uuid', args[0].uuid)
+
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
@@ -84,7 +97,6 @@ class MetaTensor(torch.Tensor):
 
         return tree_map(wrap, out)
 
-    @singledispatchmethod
     def to(self, *args, **kwargs) -> torch.Tensor:
         """An extension of `torch.Tensor.to()` to MetaTensor
 
@@ -101,14 +113,13 @@ class MetaTensor(torch.Tensor):
         MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
         """
         # this imitates c++ function in the way of @overload
-        return super().to(*args, **kwargs)
-
-    @to.register
-    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
-
-    @to.register
-    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
+        device = None
+        for arg in args:
+            if isinstance(arg, str) or isinstance(arg, _device):
+                device = arg
+        if 'device' in kwargs:
+            device = kwargs['device']
+        result = super().to(*args, **kwargs)
+        if device is not None:
+            result = MetaTensor(deepcopy(result), fake_device=device)
+        return result
@@ -13,6 +13,9 @@ from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -74,7 +77,7 @@ def _run_ckpt_solver(rank):
         m = model_cls(num_classes=5)
         graph = tracer.trace(root=m)
         gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-        MetaInfoProp(gm).run(data)
+        MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
        codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
         if solver == solver_rotor:
@@ -89,7 +92,6 @@ def _run_ckpt_solver(rank):
 
 
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
-@pytest.mark.skip('TODO: refactor ckpt solvers')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
 
@@ -111,7 +113,7 @@ def _run_ckpt_solver_torch11(rank):
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         if solver == solver_rotor:
-            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
         else:
             gm = solver(gm)
         assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
@@ -129,5 +131,5 @@ def test_ckpt_solver_torch11():
 
 if __name__ == '__main__':
     _run_ckpt_solver(rank=0)
-    # test_ckpt_solver()
-    # test_ckpt_solver_torch11()
+    test_ckpt_solver()
+    test_ckpt_solver_torch11()
@@ -1,3 +1,4 @@
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 import torch
 import torchvision.models as tm
 from colossalai.fx import ColoTracer
@@ -5,6 +6,9 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.algorithms import solver_rotor, linearize
 from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -15,7 +19,7 @@ except:
     with_codegen = False
 
 
-@pytest.mark.skip(reason='TODO: modify calculations in rotor')
+@pytest.mark.skip(reason='TODO: modify the logger')
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 def test_linearize():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@@ -26,6 +30,7 @@ def test_linearize():
         graph = tracer.trace(model)
         graph.set_codegen(ActivationCheckpointCodeGen())
         gm = ColoGraphModule(model, graph, model.__class__.__name__)
+        MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu'))
         node_list = linearize(gm)
         gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
         op_list = gm.__sequence__.list_operations()