[fx/profiler] assigned a UUID to each unrecorded tensor / improved performance on GPT-2 (#1679)

* [fx/profiler] replace data_ptr with uuid for all tensors.

* [fx] modify uuid.

* [fx/profiler] tune performance on GPT-2.

* [fx] updates.

* [fx] debug.

* [fx] debug.

* [fx] cuda.
Super Daniel 2022-10-11 11:03:35 +08:00 committed by GitHub
parent 0df5034a36
commit 3dd6994427
13 changed files with 262 additions and 94 deletions
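The central change in this commit is how the profiler identifies tensors: instead of keying its bookkeeping on data_ptr, it attaches a uuid to every tensor it has not seen before and copies that uuid onto detached copies and alias outputs. The following is a minimal sketch of that idea (toy code, not the library internals; only set_uuid mirrors a function added in this diff):

import uuid
import torch

def set_uuid(x):
    # same idea as set_uuid in the MetaTensor changes further down in this diff
    if isinstance(x, torch.Tensor) and not hasattr(x, 'uuid'):
        x.uuid = uuid.uuid4()

t = torch.empty(1024, device='meta')
set_uuid(t)

detached = t.detach()
detached.uuid = t.uuid    # a plain attribute does not survive detach(), so the pack hook copies it over

alias = t.view(32, 32)
alias.uuid = t.uuid       # mirrors the ALIAS_ATEN handling in MetaTensor.__torch_dispatch__

print(t.uuid == detached.uuid == alias.uuid)    # True: one logical activation, one identity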


@ -2,6 +2,7 @@ from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
import math
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
__all__ = ['chen_greedy']
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
@ -74,10 +75,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
n: Node
temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += n.meta['fwd_mem_out']
x += calculate_fwd_in(n)
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1


@ -1,8 +1,9 @@
import sys
from typing import List, Tuple
from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@ -124,9 +125,7 @@ def _fwd_xbar(node: List[Node]) -> int:
xbar = 0
for n in node:
xbar += n.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
xbar += n.meta['fwd_mem_out']
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
return xbar
@ -166,6 +165,21 @@ def _bwd_time(node: List[Node]) -> int:
return bwd_time
def _get_fwd_mem_tmp(node: List[Node]) -> int:
"""Get the forward temp memory of a node
This could be done by subtracting the saved activation from all output of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: forward temp memory, unit Byte
"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
def _get_bwd_mem_tmp(node: List[Node]) -> int:
"""Get the backward temp memory of a node
@ -184,9 +198,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
deps_size -= k.meta['fwd_mem_out']
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
@ -212,15 +224,15 @@ def _construct_chain(node_list: List[List[Node]], input) -> Chain:
bwd_time = []
xbar_sizes = [activation_size(input)]
x_sizes = [activation_size(input)]
# currently we can't get the temp memory needed in fwd
tmp_fwd = [0] * len(node_list)
tmp_fwd = []
tmp_bwd = []
for idx, node in enumerate(node_list):
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(node[-1].meta['fwd_mem_out'])
x_sizes.append(calculate_fwd_out(node[-1]))
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_fwd.append(_get_fwd_mem_tmp(node))
tmp_bwd.append(_get_bwd_mem_tmp(node))
bwd_time.append(0)
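A worked reading of the new _get_fwd_mem_tmp above, with made-up numbers: a node may produce more output than its users actually keep, and only the difference counts as temporary forward memory.

total_fwd_out = 12 * 1024**2        # activation_size(n.meta['fwd_out']): everything the node produces
kept_by_users = 8 * 1024**2         # calculate_fwd_out(n): outputs whose uuid reappears in the users' fwd_in
fwd_mem_tmp = total_fwd_out - kept_by_users
print(fwd_mem_tmp / 1024**2)        # 4.0 MiB of forward temp memory that checkpointing can drop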


@ -1,12 +1,11 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, List, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
@compatibility(is_backward_compatible=True)
@ -62,12 +61,12 @@ class MetaInfoProp(torch.fx.Interpreter):
# output of above code is
Op type Op Forward FLOPs Backward FLOPs SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- ------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 128 FLOPs 288 FLOPs True 0.12 KB 0.00 KB 0.34 KB 0.00 KB
call_module _1 512 FLOPs 1,056 FLOPs True 0.12 KB 0.00 KB 1.19 KB 0.00 KB
output output 0 FLOPs 0 FLOPs True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 128 FLOPs 288 FLOPs 0.12 KB 0.00 KB 0.34 KB 0.00 KB
call_module _1 512 FLOPs 1,056 FLOPs 0.12 KB 0.00 KB 1.19 KB 0.00 KB
output output 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed
@ -102,7 +101,7 @@ class MetaInfoProp(torch.fx.Interpreter):
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
n.meta['type'] = type(result)
# retain the autograd graph
@ -228,6 +227,8 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
if hasattr(args[0], '_tensor'):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
def propagate(self, *args):
@ -281,9 +282,9 @@ class MetaInfoProp(torch.fx.Interpreter):
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
@ -295,7 +296,7 @@ class MetaInfoProp(torch.fx.Interpreter):
'Op',
'Forward FLOPs',
'Backward FLOPs',
'SAVE_FWD_IN',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
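A hedged usage sketch of the updated MetaInfoProp output, assuming a torch version with META_COMPATIBILITY; the model and shapes are arbitrary. After run(), the per-node FWD_IN / FWD_OUT / FWD_TMP columns are derived through the calculate_fwd_* helpers rather than the removed SAVE_FWD_IN flag.

import torch
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from colossalai.fx.profiler.tensor import MetaTensor

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
graph = ColoTracer().trace(model)
gm = ColoGraphModule(model, graph, model.__class__.__name__)

data = torch.rand(4, 8, device='meta')
MetaInfoProp(gm).run(MetaTensor(data, fake_device='cpu'))

for node in gm.graph.nodes:
    # bytes saved from the inputs, bytes handed to users, and temporary forward bytes
    print(node.name, calculate_fwd_in(node), calculate_fwd_out(node), calculate_fwd_tmp(node))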


@ -3,8 +3,9 @@ if META_COMPATIBILITY:
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo
from .memory import parameter_size, activation_size, is_inplace


@ -1,7 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Dict
from typing import Dict, List
from torch.fx import Graph, Node
from .memory import activation_size, is_inplace
@ -39,16 +39,25 @@ class GraphInfo:
bwd_flop (int): The backward FLOPs of a certain node.
bwd_time (float): The real backward time (s) of a certain node.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_in (List): See the above illustration.
fwd_tmp (List): See the above illustration.
fwd_out (List): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
# TODO(super-dainiu): remove redundant items; currently all of them are necessary for development
fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False
fwd_in: List = field(default_factory=list)
fwd_tmp: List = field(default_factory=list)
fwd_out: List = field(default_factory=list)
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
return n.meta['phase'] == phase
def is_saved(n: Node):
return len(n.meta['saved_tensor'])
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked for keyword `phase`.
@ -113,9 +118,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
graph_info.fwd_in += n.meta['saved_tensor']
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
graph_info.fwd_tmp += n.meta['saved_tensor']
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
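The practical effect of the GraphInfo change above is that saved tensors are now kept as lists (fwd_in / fwd_tmp / fwd_out) and reduced to bytes only on demand. A small illustrative sketch with toy meta tensors, not real graph output:

import torch
from colossalai.fx.profiler import GraphInfo, activation_size

saved_input = torch.empty(1024, device='meta')    # as collected from a PLACEHOLDER-phase node
saved_tmp = torch.empty(256, device='meta')       # as collected from a FORWARD-phase node

info = GraphInfo(fwd_in=[saved_input], fwd_tmp=[saved_tmp], fwd_out=[])
print(activation_size(info.fwd_in))    # 4096 (1024 float32 elements, 4 bytes each)
print(activation_size(info.fwd_tmp))   # 1024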


@ -1,4 +1,5 @@
from .registry import meta_profiler_function, meta_profiler_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .profiler_function import *
from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module


@ -0,0 +1,42 @@
# fallback helpers for PyTorch 1.11 compatibility
import torch
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
def calculate_fwd_in(n: Node) -> bool:
"""A helper function to calculate `fwd_in`
Args:
n (Node): a node from the graph
Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
return n.meta['save_fwd_in']
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
Args:
n (Node): a node from the graph
Returns:
fwd_tmp (int): the result of `fwd_tmp`
"""
return n.meta["fwd_mem_tmp"]
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`
Args:
n (Node): a node from the graph
Returns:
fwd_out (int): the result of `fwd_out`
"""
return n.meta['fwd_mem_out']
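These fallbacks keep the call sites uniform: passes such as chen_greedy and solver_rotor import calculate_fwd_* from colossalai.fx.profiler and do not need to know whether the uuid-based implementation or this PyTorch 1.11 fallback is active. A minimal sketch of such a call site (node_fwd_budget is a hypothetical helper, not in the diff):

from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

def node_fwd_budget(n) -> int:
    # the same per-node quantity chen_greedy accumulates, whichever backend is active
    return calculate_fwd_in(n) + calculate_fwd_tmp(n)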


@ -1,9 +1,11 @@
import torch
from torch.fx import Node
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple
from . import META_COMPATIBILITY
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
__all__ = [
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
@ -21,7 +23,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list):
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@ -42,6 +44,61 @@ def parameter_size(mod: torch.nn.Module) -> int:
return param_size
def calculate_fwd_in(n: Node) -> int:
"""A helper function to calculate `fwd_in`
Args:
n (Node): a node from the graph
Returns:
fwd_in (int): the result of `fwd_in`
"""
return activation_size(n.meta["fwd_in"])
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
Args:
n (Node): a node from the graph
Returns:
fwd_tmp (int): the result of `fwd_tmp`
"""
def is_relu_node(n: Node) -> bool:
if n.op == 'call_function':
return n.target in [torch.nn.functional.relu]
elif n.op == 'call_module':
return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
return False
if not is_relu_node(n):
return activation_size(n.meta["fwd_tmp"])
return 0
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`
Args:
n (Node): a node from the graph
Returns:
fwd_out (int): the result of `fwd_out`
"""
def intersect(a, b):
return {k: a[k] for k in a if k in b}
fwd_in = dict()
for u in n.users:
fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
return activation_size(intersect(fwd_in, fwd_out))
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node


@ -226,6 +226,7 @@ flop_mapping = {
aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
aten.embedding.default: elementwise_flop_counter(1, 0),
}
elementwise_flop_aten = [
@ -304,10 +305,12 @@ zero_flop_aten = [
aten.transpose.int,
aten._to_copy.default,
aten.unsqueeze.default,
aten.unbind.int,
aten._unsafe_view.default,
aten.view.default,
aten.where.self,
aten.zero_.default,
aten.zeros_like.default,
]
for op in zero_flop_aten:


@ -18,6 +18,9 @@ __all__ = ['profile_function', 'profile_module', 'profile_method']
# track duplicated tensors between nodes
cache = set()
# a global flag that disables caching while profiling inplace ops
do_not_cache = False
def normalize_tuple(x):
if not isinstance(x, tuple):
@ -223,10 +226,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
kwargs = tree_map(wrap, kwargs)
def pack(x):
global cache
if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
x._node.meta['saved_tensor'] += [x]
cache.add(x._tensor.data_ptr)
global cache, do_not_cache
if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
tensor = x._tensor.detach()
tensor.uuid = x._tensor.uuid
x._node.meta['saved_tensor'] += [tensor]
if not do_not_cache:
cache.add(x._tensor.uuid)
return x
def unpack(x):
@ -245,16 +251,25 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
phase = Phase.BACKWARD
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
phase = Phase.BACKWARD
torch.autograd.backward(
out,
grad_out,
)
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
graph_info.fwd_mem_out = activation_size(out)
def extract_tensor(x: Any):
if isinstance(x, MetaTensor):
tensor = x._tensor.detach()
tensor.uuid = x._tensor.uuid
return tensor
return x
graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
def unwrap(x):
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
@ -279,32 +294,39 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# If there is an argument that this `call_function` is inplace, we should
# still run the profiling but discard some results regarding `target`
inplace = kwargs.get('inplace', False)
if inplace:
kwargs['inplace'] = False
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
# currently we set the fwd_mem_tmp of ReLU to zero
if target in [torch.nn.functional.relu]:
meta.save_fwd_in = False
meta.bwd_mem_out = 0
meta.fwd_mem_tmp = 0
else:
out, meta = _profile_concrete(func, *args, **kwargs)
# find the grad for parameter in args and kwargs
param_size = 0
def get_param_size(x):
nonlocal param_size
if isinstance(x, Parameter):
param_size += activation_size(x)
tree_map(get_param_size, args)
tree_map(get_param_size, kwargs)
# If there is an argument that this `call_function` is inplace, we should
# still run the profiling but discard some results regarding `target`
global do_not_cache
inplace = kwargs.get('inplace', False)
if inplace or target in [torch.nn.functional.relu]:
do_not_cache = True
kwargs['inplace'] = False
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
# currently we set the fwd_mem_tmp of ReLU to zero
if target in [torch.nn.functional.relu]:
meta.fwd_in = []
meta.fwd_tmp = []
meta.bwd_mem_out = 0
meta.fwd_mem_tmp = 0
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
kwargs['inplace'] = True
do_not_cache = False
meta.bwd_mem_out -= param_size
return out, meta
@ -348,25 +370,30 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# If there is an argument that this `call_module` is inplace, we should
# still run the profiling but discard some results regarding `module`.
inplace = getattr(module, 'inplace', False)
# calculate parameter size
param_size = parameter_size(module)
if inplace:
# If there is an argument that this `call_module` is inplace, we should
# still run the profiling but discard some results regarding `module`.
global do_not_cache
inplace = getattr(module, 'inplace', False)
if inplace or type(module) in [torch.nn.ReLU]:
do_not_cache = True
module.inplace = False
if device == 'meta':
out, meta = _profile_meta(func, *args, **kwargs)
# currently we set the fwd_mem_tmp of ReLU to zero
if type(module) in [torch.nn.modules.activation.ReLU]:
meta.save_fwd_in = False
# currently we set the fwd_tmp of ReLU to []
if type(module) in [torch.nn.ReLU]:
meta.fwd_in = []
meta.fwd_tmp = []
meta.bwd_mem_out = 0
meta.fwd_mem_tmp = 0
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
module.inplace = True
do_not_cache = False
# grad for param will not be counted
meta.bwd_mem_out -= param_size
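One plausible reading of the do_not_cache toggling above, as a standalone sketch (direct calls to a toy pack hook, no autograd involved): while an inplace op or ReLU is being profiled, its saved tensor is still recorded for this node, but its uuid stays out of the global cache so that a later node saving the same tensor is not silently skipped.

import uuid
import torch

cache, saved = set(), []
do_not_cache = False

def pack(x):
    # mirrors the pack hook in _profile_meta: tag, detach, carry the uuid over,
    # and only remember the uuid when caching is allowed
    if isinstance(x, torch.Tensor):
        if not hasattr(x, 'uuid'):
            x.uuid = uuid.uuid4()
        if x.uuid not in cache:
            t = x.detach()
            t.uuid = x.uuid
            saved.append(t)
            if not do_not_cache:
                cache.add(x.uuid)
    return x

x = torch.rand(4)

do_not_cache = True     # profiling an inplace op (e.g. F.relu(..., inplace=True))
pack(x)
do_not_cache = False

pack(x)                 # a later, regular node saves the same tensor and is still recorded
print(len(saved), len(cache))   # 2 1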


@ -1,13 +1,20 @@
from copy import deepcopy
from typing import Optional, Union, overload
from typing import Optional
import torch
from torch.utils._pytree import tree_map, tree_flatten
from torch.types import _bool, _dtype, _device
from functools import singledispatchmethod
import uuid
from .constant import ALIAS_ATEN
__all__ = ['MetaTensor']
def set_uuid(x):
if isinstance(x, torch.Tensor):
if not hasattr(x, 'uuid'):
setattr(x, 'uuid', uuid.uuid4())
class MetaTensor(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
@ -42,6 +49,7 @@ class MetaTensor(torch.Tensor):
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
# only tensor not on `meta` should be copied to `meta`
set_uuid(r._tensor)
return r
def __repr__(self):
@ -73,6 +81,11 @@ class MetaTensor(torch.Tensor):
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
# here we keep the uuid of the input because ALIAS_ATEN ops do not create a physical copy
# of the input
if func in ALIAS_ATEN:
setattr(out, 'uuid', args[0].uuid)
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
@ -84,7 +97,6 @@ class MetaTensor(torch.Tensor):
return tree_map(wrap, out)
@singledispatchmethod
def to(self, *args, **kwargs) -> torch.Tensor:
"""An extension of `torch.Tensor.to()` to MetaTensor
@ -101,14 +113,13 @@ class MetaTensor(torch.Tensor):
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
"""
# this imitates c++ function in the way of @overload
return super().to(*args, **kwargs)
@to.register
def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
result = super().to(dtype, non_blocking, copy) if dtype is not None else self
return MetaTensor(deepcopy(result), fake_device=device)
@to.register
def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
result = super().to(dtype, non_blocking, copy) if dtype is not None else self
return MetaTensor(deepcopy(result), fake_device=device)
device = None
for arg in args:
if isinstance(arg, str) or isinstance(arg, _device):
device = arg
if 'device' in kwargs:
device = kwargs['device']
result = super().to(*args, **kwargs)
if device is not None:
result = MetaTensor(deepcopy(result), fake_device=device)
return result
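A hedged usage sketch of the rewritten to(): any positional or keyword device is picked up, the computation stays on the meta device, and the result is rewrapped with that device as fake_device, so no CUDA (or, as in the docstring above, vulkan) backend needs to be present.

import torch
from colossalai.fx.profiler.tensor import MetaTensor

x = MetaTensor(torch.rand(10, device='meta'), fake_device='cpu')
y = x.to(torch.device('cuda'))    # no real CUDA allocation: the data stays on the meta device
print(y)                          # a MetaTensor reporting fake_device='cuda'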


@ -13,6 +13,9 @@ from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx.profiler.tensor import MetaTensor
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -74,7 +77,7 @@ def _run_ckpt_solver(rank):
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
if solver == solver_rotor:
@ -89,7 +92,6 @@ def _run_ckpt_solver(rank):
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
@pytest.mark.skip('TODO: refactor ckpt solvers')
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)
@ -111,7 +113,7 @@ def _run_ckpt_solver_torch11(rank):
MetaInfoProp(gm).run(data)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
if solver == solver_rotor:
gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
else:
gm = solver(gm)
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
@ -129,5 +131,5 @@ def test_ckpt_solver_torch11():
if __name__ == '__main__':
_run_ckpt_solver(rank=0)
# test_ckpt_solver()
# test_ckpt_solver_torch11()
test_ckpt_solver()
test_ckpt_solver_torch11()


@ -1,3 +1,4 @@
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
@ -5,6 +6,9 @@ from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx.profiler.tensor import MetaTensor
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -15,7 +19,7 @@ except:
with_codegen = False
@pytest.mark.skip(reason='TODO: modify calculations in rotor')
@pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@ -26,6 +30,7 @@ def test_linearize():
graph = tracer.trace(model)
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu'))
node_list = linearize(gm)
gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
op_list = gm.__sequence__.list_operations()