mirror of https://github.com/hpcaitech/ColossalAI
[fx/profiler] assigned UUID to each unrecorded tensor/ improved performance on GPT-2 (#1679)
* [fx/profiler] modify data_ptr into uuid for all tensors.
* [fx] modify uuid.
* [fx/profiler] tune performance on GPT-2.
* [fx] updates.
* [fx] debug.
* [fx] debug.
* [fx] cuda.
parent 0df5034a36
commit 3dd6994427
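The gist of the change, as a rough standalone sketch rather than the patch itself (the helper names `tag_uuid` and `unique_bytes` and the example tensors are illustrative only): every tensor the profiler touches is tagged with a persistent `uuid` attribute, and saved activations are deduplicated by that uuid instead of by `data_ptr`, which is not a reliable identity for tensors on the `meta` device since they own no real storage.

import uuid
import torch

def tag_uuid(t: torch.Tensor) -> torch.Tensor:
    # illustrative counterpart of the patch's `set_uuid`: attach a stable
    # identifier that aliasing ops can propagate to their outputs
    if not hasattr(t, 'uuid'):
        t.uuid = uuid.uuid4()
    return t

def unique_bytes(tensors) -> int:
    # deduplicate by uuid rather than data_ptr, so meta tensors and aliased
    # views are each counted exactly once
    seen = {}
    for t in tensors:
        if isinstance(t, torch.Tensor) and hasattr(t, 'uuid'):
            seen[t.uuid] = t
    return sum(t.numel() * t.element_size() for t in seen.values())

x = tag_uuid(torch.empty(4, 4, device='meta'))
v = x.view(16)
v.uuid = x.uuid                    # an aliasing op keeps the producer's uuid
print(unique_bytes([x, v]))        # 64 bytes for fp32, counted once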
@@ -2,6 +2,7 @@ from typing import List, Set, Tuple
 import torch
 from torch.fx import GraphModule, Node
 import math
+from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

 __all__ = ['chen_greedy']
 CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
@@ -74,10 +75,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
         n: Node
-        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
+        temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += n.meta['fwd_mem_out']
+            x += calculate_fwd_in(n)
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
@@ -1,8 +1,9 @@
 import sys
 from typing import List, Tuple
+from colossalai.fx.profiler.memory import calculate_fwd_in
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -124,9 +125,7 @@ def _fwd_xbar(node: List[Node]) -> int:

     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp']
-        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
-            xbar += n.meta['fwd_mem_out']
+        xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
     return xbar


@@ -166,6 +165,21 @@ def _bwd_time(node: List[Node]) -> int:
     return bwd_time


+def _get_fwd_mem_tmp(node: List[Node]) -> int:
+    """Get the forward temp memory of a node
+    This could be done by subtracting the saved activation from all output of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in linearized graph
+
+    Returns:
+        int: forward temp memory, unit Byte
+    """
+    n = node[-1]
+    return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
+
+
 def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node

@@ -184,9 +198,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         if v > 0:
             deps_size += k.meta['bwd_mem_out']
         if v == float('-inf'):
-            deps_size -= k.meta['fwd_mem_tmp']
-            if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
-                deps_size -= k.meta['fwd_mem_out']
+            deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)

     return deps_size

@@ -212,15 +224,15 @@ def _construct_chain(node_list: List[List[Node]], input) -> Chain:
     bwd_time = []
     xbar_sizes = [activation_size(input)]
     x_sizes = [activation_size(input)]
-    # currently we can't get the temp memory needed in fwd
-    tmp_fwd = [0] * len(node_list)
+    tmp_fwd = []
     tmp_bwd = []

     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(node[-1].meta['fwd_mem_out'])
+        x_sizes.append(calculate_fwd_out(node[-1]))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
+        tmp_fwd.append(_get_fwd_mem_tmp(node))
         tmp_bwd.append(_get_bwd_mem_tmp(node))

     bwd_time.append(0)
@@ -1,12 +1,11 @@
 from dataclasses import asdict
-from colossalai.fx.profiler import GraphInfo
 import torch
 import torch.fx
 from torch.fx.node import Node, Argument, Target
 from torch.utils._pytree import tree_map
 from typing import Any, List, Tuple, NamedTuple, Dict
 from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in


 @compatibility(is_backward_compatible=True)
@@ -62,12 +61,12 @@ class MetaInfoProp(torch.fx.Interpreter):


         # output of above code is
-        Op type      Op       Forward FLOPs    Backward FLOPs    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
-        -----------  -------  ---------------  ----------------  -------------  ---------  ---------  ---------  ---------
-        placeholder  input_1  0 FLOPs          0 FLOPs           False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
-        call_module  _0       128 FLOPs        288 FLOPs         True           0.12 KB    0.00 KB    0.34 KB    0.00 KB
-        call_module  _1       512 FLOPs        1,056 FLOPs       True           0.12 KB    0.00 KB    1.19 KB    0.00 KB
-        output       output   0 FLOPs          0 FLOPs           True           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        Op type      Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
+        -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
+        placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        call_module  _0       128 FLOPs        288 FLOPs         0.12 KB    0.00 KB    0.34 KB    0.00 KB
+        call_module  _1       512 FLOPs        1,056 FLOPs       0.12 KB    0.00 KB    1.19 KB    0.00 KB
+        output       output   0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB

     Args:
         module (GraphModule): The module to be executed
@@ -102,7 +101,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         n.meta['tensor_meta'] = tensor_meta
         n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
-        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
+        setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
         n.meta['type'] = type(result)

         # retain the autograd graph
@@ -228,6 +227,8 @@ class MetaInfoProp(torch.fx.Interpreter):
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
+        if hasattr(args[0], '_tensor'):
+            return args[0], GraphInfo(fwd_in=[args[0]._tensor])
         return args[0], GraphInfo(save_fwd_in=True)

     def propagate(self, *args):
@@ -281,9 +282,9 @@ class MetaInfoProp(torch.fx.Interpreter):
                 str(node),
                 flops_repr(node.meta['fwd_flop']),
                 flops_repr(node.meta['bwd_flop']),
-                node.meta['save_fwd_in'],
-                mem_repr(node.meta['fwd_mem_out']),
-                mem_repr(node.meta['fwd_mem_tmp']),
+                mem_repr(calculate_fwd_in(node)),
+                mem_repr(calculate_fwd_out(node)),
+                mem_repr(calculate_fwd_tmp(node)),
                 mem_repr(node.meta['bwd_mem_out']),
                 mem_repr(node.meta['bwd_mem_tmp']),
             ])
@@ -295,7 +296,7 @@ class MetaInfoProp(torch.fx.Interpreter):
                 'Op',
                 'Forward FLOPs',
                 'Backward FLOPs',
-                'SAVE_FWD_IN',
+                'FWD_IN',
                 'FWD_OUT',
                 'FWD_TMP',
                 'BWD_OUT',
@@ -3,8 +3,9 @@ if META_COMPATIBILITY:
     from .opcount import flop_mapping
     from .tensor import MetaTensor
     from .profiler import profile_function, profile_method, profile_module
+    from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 else:
-    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
+    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

 from .dataflow import GraphInfo
 from .memory import parameter_size, activation_size, is_inplace
@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace

@@ -39,16 +39,25 @@ class GraphInfo:
     bwd_flop (int): The backward FLOPs of a certain node.
     bwd_time (float): The real backward time (s) of a certain node.
     save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+    fwd_in (List): See the above illustration.
+    fwd_tmp (List): See the above illustration.
+    fwd_out (List): See the above illustration.
     fwd_mem_tmp (int): See the above illustration.
     fwd_mem_out (int): See the above illustration.
     bwd_mem_tmp (int): See the above illustration.
     bwd_mem_out (int): See the above illustration.
     """

+    # TODO(super-dainiu): removed redundant items, currently all of them are necessary for development
     fwd_flop: int = 0
     fwd_time: float = 0.0
     bwd_flop: int = 0
     bwd_time: float = 0.0
     save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
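A toy reading of the extended GraphInfo (the values and tensor shape below are invented for illustration): fwd_in/fwd_tmp/fwd_out now hold the saved tensors themselves, and byte counts are derived from those lists on demand with activation_size rather than being stored up front as fwd_mem_* integers.

import torch
from colossalai.fx.profiler import GraphInfo, activation_size

# toy GraphInfo: the profiler records the saved tensors themselves; sizes in
# bytes are computed lazily from those lists
saved = [torch.empty(128, 64, device='meta')]
info = GraphInfo(fwd_flop=1000, fwd_in=saved, fwd_tmp=[], fwd_out=saved)
print(activation_size(info.fwd_in))    # 128 * 64 * 4 bytes = 32768 for fp32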
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase


-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
@@ -113,9 +118,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+            graph_info.fwd_in += n.meta['saved_tensor']
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+            graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
@@ -1,4 +1,5 @@
 from .registry import meta_profiler_function, meta_profiler_module
+from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 from .profiler_function import *
 from .profiler_module import *
 from .profiler import profile_function, profile_method, profile_module
@@ -0,0 +1,42 @@
+# for PyTorch 1.11 compatibility uses
+import torch
+from torch.fx import Node, GraphModule
+from typing import Union, Dict, List, Tuple
+
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
+
+
+def calculate_fwd_in(n: Node) -> bool:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        save_fwd_in (bool): the result of `save_fwd_in`
+    """
+    return n.meta['save_fwd_in']
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+    return n.meta["fwd_mem_tmp"]
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+    return n.meta['fwd_mem_out']
@@ -1,9 +1,11 @@
 import torch
-from torch.fx import Node
+from torch.fx import Node, GraphModule
 from typing import Union, Dict, List, Tuple
 from . import META_COMPATIBILITY

-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = [
+    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
+]


 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
@@ -21,7 +23,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     elif isinstance(out, dict):
         value_list = [v for _, v in out.items()]
         act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list):
+    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
         for element in out:
             act_size += activation_size(element)
     return act_size
@@ -42,6 +44,61 @@ def parameter_size(mod: torch.nn.Module) -> int:
     return param_size


+def calculate_fwd_in(n: Node) -> int:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_in (int): the result of `fwd_in`
+    """
+    return activation_size(n.meta["fwd_in"])
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+
+    def is_relu_node(n: Node) -> bool:
+        if n.op == 'call_function':
+            return n.target in [torch.nn.functional.relu]
+        elif n.op == 'call_module':
+            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+        return False
+
+    if not is_relu_node(n):
+        return activation_size(n.meta["fwd_tmp"])
+    return 0
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+
+    def intersect(a, b):
+        return {k: a[k] for k in a if k in b}
+
+    fwd_in = dict()
+    for u in n.users:
+        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(intersect(fwd_in, fwd_out))
+
+
 def is_inplace(n: Node):
     """Get the inplace argument from torch.fx.Node

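To give a rough sense of how these three helpers compose for the solvers above, here is a small sketch (the function name `per_node_activation_bytes` is hypothetical, and it assumes `MetaInfoProp` has already annotated the graph):

from torch.fx import GraphModule
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

def per_node_activation_bytes(gm: GraphModule) -> dict:
    # hypothetical helper: collect the per-node activation memory (in bytes)
    # that the checkpointing solvers reason about, assuming MetaInfoProp has
    # already filled n.meta['fwd_in'] / ['fwd_tmp'] / ['fwd_out']
    return {
        n.name: calculate_fwd_in(n) + calculate_fwd_tmp(n) + calculate_fwd_out(n)
        for n in gm.graph.nodes
    }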
@@ -226,6 +226,7 @@ flop_mapping = {
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding.default: elementwise_flop_counter(1, 0),
 }

 elementwise_flop_aten = [
@@ -304,10 +305,12 @@ zero_flop_aten = [
     aten.transpose.int,
     aten._to_copy.default,
     aten.unsqueeze.default,
+    aten.unbind.int,
     aten._unsafe_view.default,
     aten.view.default,
     aten.where.self,
     aten.zero_.default,
+    aten.zeros_like.default,
 ]

 for op in zero_flop_aten:
@@ -18,6 +18,9 @@ __all__ = ['profile_function', 'profile_module', 'profile_method']
 # track duplicated tensors between nodes
 cache = set()

+# a global identifier for inplace ops
+do_not_cache = False
+

 def normalize_tuple(x):
     if not isinstance(x, tuple):
@@ -223,10 +226,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     kwargs = tree_map(wrap, kwargs)

     def pack(x):
-        global cache
-        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x]
-            cache.add(x._tensor.data_ptr)
+        global cache, do_not_cache
+        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            x._node.meta['saved_tensor'] += [tensor]
+            if not do_not_cache:
+                cache.add(x._tensor.uuid)
         return x

     def unpack(x):
@@ -245,16 +251,25 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:

     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
-    for tensor in normalize_tuple(out):
-        if is_autogradable(tensor) and tensor.requires_grad:
+    if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
+        grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
         phase = Phase.BACKWARD
-            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
-                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
-            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
+        torch.autograd.backward(
+            out,
+            grad_out,
+        )

     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
-    graph_info.fwd_mem_out = activation_size(out)
+
+    def extract_tensor(x: Any):
+        if isinstance(x, MetaTensor):
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            return tensor
+        return x
+
+    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))

     def unwrap(x):
         return MetaTensor(x) if isinstance(x, torch.Tensor) else x
@@ -279,32 +294,39 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

-        # If there is an argument that this `call_function` is inplace, we should
-        # still run the profiling but discard some results regarding `target`
-        inplace = kwargs.get('inplace', False)
-        if inplace:
-            kwargs['inplace'] = False
-        if device == 'meta':
-            out, meta = _profile_meta(func, *args, **kwargs)
-
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if target in [torch.nn.functional.relu]:
-                meta.save_fwd_in = False
-                meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
-        else:
-            out, meta = _profile_concrete(func, *args, **kwargs)
-
         # find the grad for parameter in args and kwargs
         param_size = 0

         def get_param_size(x):
+            nonlocal param_size
             if isinstance(x, Parameter):
                 param_size += activation_size(x)

         tree_map(get_param_size, args)
         tree_map(get_param_size, kwargs)

+        # If there is an argument that this `call_function` is inplace, we should
+        # still run the profiling but discard some results regarding `target`
+        global do_not_cache
+        inplace = kwargs.get('inplace', False)
+        if inplace or target in [torch.nn.functional.relu]:
+            do_not_cache = True
+            kwargs['inplace'] = False
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+            # currently we set the fwd_mem_tmp of ReLU to zero
+            if target in [torch.nn.functional.relu]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
+                meta.bwd_mem_out = 0
+                meta.fwd_mem_tmp = 0
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
+
+        if inplace:
+            kwargs['inplace'] = True
+        do_not_cache = False
+
         meta.bwd_mem_out -= param_size
         return out, meta

@@ -348,25 +370,30 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

-        # If there is an argument that this `call_module` is inplace, we should
-        # still run the profiling but discard some results regarding `module`.
-        inplace = getattr(module, 'inplace', False)
-
         # calculate parameter size
         param_size = parameter_size(module)

-        if inplace:
+        # If there is an argument that this `call_module` is inplace, we should
+        # still run the profiling but discard some results regarding `module`.
+        global do_not_cache
+
+        inplace = getattr(module, 'inplace', False)
+        if inplace or type(module) in [torch.nn.ReLU]:
+            do_not_cache = True
             module.inplace = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if type(module) in [torch.nn.modules.activation.ReLU]:
-                meta.save_fwd_in = False
+            # currently we set the fwd_tmp of ReLU to []
+            if type(module) in [torch.nn.ReLU]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
                 meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
+        if inplace:
+            module.inplace = True
+        do_not_cache = False

         # grad for param will not be counted
         meta.bwd_mem_out -= param_size
@@ -1,13 +1,20 @@
 from copy import deepcopy
-from typing import Optional, Union, overload
+from typing import Optional
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
 from torch.types import _bool, _dtype, _device
-from functools import singledispatchmethod
+import uuid
+from .constant import ALIAS_ATEN

 __all__ = ['MetaTensor']


+def set_uuid(x):
+    if isinstance(x, torch.Tensor):
+        if not hasattr(x, 'uuid'):
+            setattr(x, 'uuid', uuid.uuid4())
+
+
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
@@ -42,6 +49,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
+        set_uuid(r._tensor)
         return r

     def __repr__(self):
@@ -73,6 +81,11 @@ class MetaTensor(torch.Tensor):
         # run aten for backend=CPU but actually on backend=Meta
         out = func(*args, **kwargs)

+        # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
+        # of the input
+        if func in ALIAS_ATEN:
+            setattr(out, 'uuid', args[0].uuid)
+
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
@@ -84,7 +97,6 @@ class MetaTensor(torch.Tensor):

         return tree_map(wrap, out)

-    @singledispatchmethod
     def to(self, *args, **kwargs) -> torch.Tensor:
         """An extension of `torch.Tensor.to()` to MetaTensor

@@ -101,14 +113,13 @@ class MetaTensor(torch.Tensor):
             MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
         """
         # this imitates c++ function in the way of @overload
-        return super().to(*args, **kwargs)
-
-    @to.register
-    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
-
-    @to.register
-    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
+        device = None
+        for arg in args:
+            if isinstance(arg, str) or isinstance(arg, _device):
+                device = arg
+        if 'device' in kwargs:
+            device = kwargs['device']
+        result = super().to(*args, **kwargs)
+        if device is not None:
+            result = MetaTensor(deepcopy(result), fake_device=device)
+        return result
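A minimal usage sketch of the reworked `.to()` (the shapes below are arbitrary, mirroring how the updated tests drive it): the underlying tensor always stays on the `meta` backend, while `fake_device` records where it would live for real.

import torch
from colossalai.fx.profiler.tensor import MetaTensor

x = MetaTensor(torch.rand(8, 8), fake_device='cpu')
y = x.to(torch.device('cuda'), torch.float16)   # device is picked out of the positional args
print(y)                                        # still a MetaTensor, now carrying fake_device='cuda'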
@@ -13,6 +13,9 @@ from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -74,7 +77,7 @@ def _run_ckpt_solver(rank):
         m = model_cls(num_classes=5)
         graph = tracer.trace(root=m)
         gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-        MetaInfoProp(gm).run(data)
+        MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
         if solver == solver_rotor:
@@ -89,7 +92,6 @@ def _run_ckpt_solver(rank):


 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
-@pytest.mark.skip('TODO: refactor ckpt solvers')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)

@@ -111,7 +113,7 @@ def _run_ckpt_solver_torch11(rank):
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         if solver == solver_rotor:
-            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
         else:
             gm = solver(gm)
         assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
@@ -129,5 +131,5 @@ def test_ckpt_solver_torch11():

 if __name__ == '__main__':
     _run_ckpt_solver(rank=0)
-    # test_ckpt_solver()
-    # test_ckpt_solver_torch11()
+    test_ckpt_solver()
+    test_ckpt_solver_torch11()
@@ -1,3 +1,4 @@
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 import torch
 import torchvision.models as tm
 from colossalai.fx import ColoTracer
@@ -5,6 +6,9 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.algorithms import solver_rotor, linearize
 from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -15,7 +19,7 @@ except:
     with_codegen = False


-@pytest.mark.skip(reason='TODO: modify calculations in rotor')
+@pytest.mark.skip(reason='TODO: modify the logger')
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 def test_linearize():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@@ -26,6 +30,7 @@ def test_linearize():
         graph = tracer.trace(model)
         graph.set_codegen(ActivationCheckpointCodeGen())
         gm = ColoGraphModule(model, graph, model.__class__.__name__)
+        MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu'))
         node_list = linearize(gm)
         gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
         op_list = gm.__sequence__.list_operations()