diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
index 860804f48..e38ddbdce 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -2,6 +2,7 @@ from typing import List, Set, Tuple
 import torch
 from torch.fx import GraphModule, Node
 import math
+from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
 
 __all__ = ['chen_greedy']
 CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
@@ -74,10 +75,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
         n: Node
-        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
+        temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += n.meta['fwd_mem_out']
+            x += calculate_fwd_in(n)
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index 08789804e..f5d7dad27 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,8 +1,9 @@
 import sys
 from typing import List, Tuple
+from colossalai.fx.profiler.memory import calculate_fwd_in
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
 import math
 
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -124,9 +125,7 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp']
-        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
-            xbar += n.meta['fwd_mem_out']
+        xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
     return xbar
 
 
@@ -166,6 +165,21 @@ def _bwd_time(node: List[Node]) -> int:
     return bwd_time
 
 
+def _get_fwd_mem_tmp(node: List[Node]) -> int:
+    """Get the forward temp memory of a node.
+    This can be computed by subtracting the saved activations from the total output of a node.
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in a linearized graph
+
+    Returns:
+        int: forward temp memory, in bytes
+    """
+    n = node[-1]
+    return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
+
+
 def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node
 
@@ -184,9 +198,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
             if v == float('-inf'):
-                deps_size -= k.meta['fwd_mem_tmp']
-                if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
-                    deps_size -= k.meta['fwd_mem_out']
+                deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
         return deps_size
 
@@ -212,15 +224,15 @@ def _construct_chain(node_list: List[List[Node]], input) -> Chain:
     bwd_time = []
     xbar_sizes = [activation_size(input)]
     x_sizes = [activation_size(input)]
-    # currently we can't get the temp memory needed in fwd
-    tmp_fwd = [0] * len(node_list)
+    tmp_fwd = []
     tmp_bwd = []
 
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(node[-1].meta['fwd_mem_out'])
+        x_sizes.append(calculate_fwd_out(node[-1]))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
+        tmp_fwd.append(_get_fwd_mem_tmp(node))
         tmp_bwd.append(_get_bwd_mem_tmp(node))
 
     bwd_time.append(0)
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 403819a29..e7435fa4e 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -1,12 +1,11 @@
 from dataclasses import asdict
-from colossalai.fx.profiler import GraphInfo
 import torch
 import torch.fx
 from torch.fx.node import Node, Argument, Target
 from torch.utils._pytree import tree_map
 from typing import Any, List, Tuple, NamedTuple, Dict
 from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
 
 
 @compatibility(is_backward_compatible=True)
@@ -62,12 +61,12 @@ class MetaInfoProp(torch.fx.Interpreter):
         # output of above code is
-        Op type        Op       Forward FLOPs    Backward FLOPs    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
-        -----------  -------  ---------------  ----------------  -------------  ---------  ---------  ---------  ---------
-        placeholder  input_1  0 FLOPs          0 FLOPs           False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
-        call_module  _0       128 FLOPs        288 FLOPs         True           0.12 KB    0.00 KB    0.34 KB    0.00 KB
-        call_module  _1       512 FLOPs        1,056 FLOPs       True           0.12 KB    0.00 KB    1.19 KB    0.00 KB
-        output       output   0 FLOPs          0 FLOPs           True           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        Op type        Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
+        -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
+        placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        call_module  _0       128 FLOPs        288 FLOPs         0.12 KB    0.00 KB    0.34 KB    0.00 KB
+        call_module  _1       512 FLOPs        1,056 FLOPs       0.12 KB    0.00 KB    1.19 KB    0.00 KB
+        output       output   0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
 
     Args:
         module (GraphModule): The module to be executed
@@ -102,7 +101,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         n.meta['tensor_meta'] = tensor_meta
         n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
-        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
+        setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
         n.meta['type'] = type(result)
 
         # retain the autograd graph
@@ -228,6 +227,8 @@ class MetaInfoProp(torch.fx.Interpreter):
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
+        if hasattr(args[0], '_tensor'):
+            return args[0], GraphInfo(fwd_in=[args[0]._tensor])
         return args[0], GraphInfo(save_fwd_in=True)
 
     def propagate(self, *args):
@@ -281,9 +282,9 @@ class MetaInfoProp(torch.fx.Interpreter):
                 str(node),
                 flops_repr(node.meta['fwd_flop']),
                 flops_repr(node.meta['bwd_flop']),
-                node.meta['save_fwd_in'],
-                mem_repr(node.meta['fwd_mem_out']),
-                mem_repr(node.meta['fwd_mem_tmp']),
+                mem_repr(calculate_fwd_in(node)),
+                mem_repr(calculate_fwd_out(node)),
+                mem_repr(calculate_fwd_tmp(node)),
                 mem_repr(node.meta['bwd_mem_out']),
                 mem_repr(node.meta['bwd_mem_tmp']),
             ])
@@ -295,7 +296,7 @@ class MetaInfoProp(torch.fx.Interpreter):
                 'Op',
                 'Forward FLOPs',
                 'Backward FLOPs',
-                'SAVE_FWD_IN',
+                'FWD_IN',
                 'FWD_OUT',
                 'FWD_TMP',
                 'BWD_OUT',
diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py
index be37fea70..fc02e0c46 100644
--- a/colossalai/fx/profiler/__init__.py
+++ b/colossalai/fx/profiler/__init__.py
@@ -3,8 +3,9 @@ if META_COMPATIBILITY:
     from .opcount import flop_mapping
     from .tensor import MetaTensor
     from .profiler import profile_function, profile_method, profile_module
+    from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 else:
-    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
+    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 
 from .dataflow import GraphInfo
 from .memory import parameter_size, activation_size, is_inplace
diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py
index 14d876a78..fe870b673 100644
--- a/colossalai/fx/profiler/dataflow.py
+++ b/colossalai/fx/profiler/dataflow.py
@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 
 from .memory import activation_size, is_inplace
@@ -39,16 +39,25 @@ class GraphInfo:
         bwd_flop (int): The backward FLOPs of a certain node.
         bwd_time (float): The real backward time (s) of a certain node.
         save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+        fwd_in (List): See the above illustration.
+        fwd_tmp (List): See the above illustration.
+        fwd_out (List): See the above illustration.
         fwd_mem_tmp (int): See the above illustration.
         fwd_mem_out (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """
+
+    # TODO(super-dainiu): remove redundant items; currently all of them are necessary for development
+
    fwd_flop: int = 0
    fwd_time: float = 0.0
    bwd_flop: int = 0
    bwd_time: float = 0.0
    save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase
 
 
-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
@@ -113,9 +118,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
             # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
             # the node, `fwd_mem_tmp` can be freed.
             if is_phase(n, Phase.PLACEHOLDER):
-                graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+                graph_info.fwd_in += n.meta['saved_tensor']
             if is_phase(n, Phase.FORWARD):
-                graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+                graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py
index b6beb7609..3dfdd2758 100644
--- a/colossalai/fx/profiler/experimental/__init__.py
+++ b/colossalai/fx/profiler/experimental/__init__.py
@@ -1,4 +1,5 @@
 from .registry import meta_profiler_function, meta_profiler_module
+from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 from .profiler_function import *
 from .profiler_module import *
 from .profiler import profile_function, profile_method, profile_module
diff --git a/colossalai/fx/profiler/experimental/memory.py b/colossalai/fx/profiler/experimental/memory.py
new file mode 100644
index 000000000..601c4cf36
--- /dev/null
+++ b/colossalai/fx/profiler/experimental/memory.py
@@ -0,0 +1,42 @@
+# for PyTorch 1.11 compatibility
+import torch
+from torch.fx import Node, GraphModule
+from typing import Union, Dict, List, Tuple
+
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
+
+
+def calculate_fwd_in(n: Node) -> bool:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        save_fwd_in (bool): the result of `save_fwd_in`
+    """
+    return n.meta['save_fwd_in']
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+    return n.meta["fwd_mem_tmp"]
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+    return n.meta['fwd_mem_out']
diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py
index 96233a9df..884de33e0 100644
--- a/colossalai/fx/profiler/memory.py
+++ b/colossalai/fx/profiler/memory.py
@@ -1,9 +1,11 @@
 import torch
-from torch.fx import Node
+from torch.fx import Node, GraphModule
 from typing import Union, Dict, List, Tuple
 from . import META_COMPATIBILITY
 
-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = [
+    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
+]
 
 
 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
@@ -21,7 +23,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     elif isinstance(out, dict):
         value_list = [v for _, v in out.items()]
         act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list):
+    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
         for element in out:
             act_size += activation_size(element)
     return act_size
@@ -42,6 +44,61 @@ def parameter_size(mod: torch.nn.Module) -> int:
     return param_size
 
 
+def calculate_fwd_in(n: Node) -> int:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_in (int): the result of `fwd_in`
+    """
+    return activation_size(n.meta["fwd_in"])
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+
+    def is_relu_node(n: Node) -> bool:
+        if n.op == 'call_function':
+            return n.target in [torch.nn.functional.relu]
+        elif n.op == 'call_module':
+            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+        return False
+
+    if not is_relu_node(n):
+        return activation_size(n.meta["fwd_tmp"])
+    return 0
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+
+    def intersect(a, b):
+        return {k: a[k] for k in a if k in b}
+
+    fwd_in = dict()
+    for u in n.users:
+        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(intersect(fwd_in, fwd_out))
+
+
 def is_inplace(n: Node):
     """Get the inplace argument from torch.fx.Node
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index 3e2662eef..22298ef26 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -226,6 +226,7 @@ flop_mapping = {
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding.default: elementwise_flop_counter(1, 0),
 }
 
 elementwise_flop_aten = [
@@ -304,10 +305,12 @@ zero_flop_aten = [
     aten.transpose.int,
     aten._to_copy.default,
     aten.unsqueeze.default,
+    aten.unbind.int,
     aten._unsafe_view.default,
     aten.view.default,
     aten.where.self,
     aten.zero_.default,
+    aten.zeros_like.default,
 ]
 
 for op in zero_flop_aten:
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index 2bb83862e..30284f64a 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -18,6 +18,9 @@ __all__ = ['profile_function', 'profile_module', 'profile_method']
 # track duplicated tensors between nodes
 cache = set()
 
+# a global flag that disables caching for inplace ops
+do_not_cache = False
+
 
 def normalize_tuple(x):
     if not isinstance(x, tuple):
@@ -223,10 +226,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     kwargs = tree_map(wrap, kwargs)
 
     def pack(x):
-        global cache
-        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x]
-            cache.add(x._tensor.data_ptr)
+        global cache, do_not_cache
+        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            x._node.meta['saved_tensor'] += [tensor]
+            if not do_not_cache:
+                cache.add(x._tensor.uuid)
         return x
 
     def unpack(x):
@@ -245,16 +251,25 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 
     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
-    for tensor in normalize_tuple(out):
-        if is_autogradable(tensor) and tensor.requires_grad:
-            phase = Phase.BACKWARD
-            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
-                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
-            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
+    if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
+        grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
+        phase = Phase.BACKWARD
+        torch.autograd.backward(
+            out,
+            grad_out,
+        )
 
     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
-    graph_info.fwd_mem_out = activation_size(out)
+
+    def extract_tensor(x: Any):
+        if isinstance(x, MetaTensor):
+            tensor = x._tensor.detach()
+            tensor.uuid = x._tensor.uuid
+            return tensor
+        return x
+
+    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
 
     def unwrap(x):
         return MetaTensor(x) if isinstance(x, torch.Tensor) else x
@@ -279,32 +294,39 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        # If there is an argument that this `call_function` is inplace, we should
-        # still run the profiling but discard some results regarding `target`
-        inplace = kwargs.get('inplace', False)
-        if inplace:
-            kwargs['inplace'] = False
-        if device == 'meta':
-            out, meta = _profile_meta(func, *args, **kwargs)
-
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if target in [torch.nn.functional.relu]:
-                meta.save_fwd_in = False
-                meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
-        else:
-            out, meta = _profile_concrete(func, *args, **kwargs)
-
         # find the grad for parameter in args and kwargs
         param_size = 0
 
         def get_param_size(x):
+            nonlocal param_size
             if isinstance(x, Parameter):
                 param_size += activation_size(x)
 
         tree_map(get_param_size, args)
         tree_map(get_param_size, kwargs)
 
+        # If there is an argument that this `call_function` is inplace, we should
+        # still run the profiling but discard some results regarding `target`
+        global do_not_cache
+        inplace = kwargs.get('inplace', False)
+        if inplace or target in [torch.nn.functional.relu]:
+            do_not_cache = True
+            kwargs['inplace'] = False
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+            # currently we set the fwd_tmp of ReLU to []
+            if target in [torch.nn.functional.relu]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
+                meta.bwd_mem_out = 0
+                meta.fwd_mem_tmp = 0
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
+
+        if inplace:
+            kwargs['inplace'] = True
+            do_not_cache = False
+
         meta.bwd_mem_out -= param_size
         return out, meta
@@ -348,25 +370,30 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        # If there is an argument that this `call_module` is inplace, we should
-        # still run the profiling but discard some results regarding `module`.
-        inplace = getattr(module, 'inplace', False)
-
         # calculate parameter size
         param_size = parameter_size(module)
 
-        if inplace:
+        # If there is an argument that this `call_module` is inplace, we should
+        # still run the profiling but discard some results regarding `module`.
+        global do_not_cache
+
+        inplace = getattr(module, 'inplace', False)
+        if inplace or type(module) in [torch.nn.ReLU]:
+            do_not_cache = True
             module.inplace = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if type(module) in [torch.nn.modules.activation.ReLU]:
-                meta.save_fwd_in = False
+            # currently we set the fwd_tmp of ReLU to []
+            if type(module) in [torch.nn.ReLU]:
+                meta.fwd_in = []
+                meta.fwd_tmp = []
                 meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
 
+        if inplace:
+            module.inplace = True
+            do_not_cache = False
+
         # grad for param will not be counted
         meta.bwd_mem_out -= param_size
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 173eb81d9..b380512a6 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -1,13 +1,20 @@
 from copy import deepcopy
-from typing import Optional, Union, overload
+from typing import Optional
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
 from torch.types import _bool, _dtype, _device
-from functools import singledispatchmethod
+import uuid
+from .constant import ALIAS_ATEN
 
 __all__ = ['MetaTensor']
 
 
+def set_uuid(x):
+    if isinstance(x, torch.Tensor):
+        if not hasattr(x, 'uuid'):
+            setattr(x, 'uuid', uuid.uuid4())
+
+
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
@@ -42,6 +49,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
+        set_uuid(r._tensor)
         return r
 
     def __repr__(self):
@@ -73,6 +81,11 @@ class MetaTensor(torch.Tensor):
         # run aten for backend=CPU but actually on backend=Meta
         out = func(*args, **kwargs)
 
+        # here we keep the uuid of the input, because ops in ALIAS_ATEN do not
+        # generate a physical copy of the input
+        if func in ALIAS_ATEN:
+            setattr(out, 'uuid', args[0].uuid)
+
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
@@ -84,7 +97,6 @@ class MetaTensor(torch.Tensor):
 
         return tree_map(wrap, out)
 
-    @singledispatchmethod
     def to(self, *args, **kwargs) -> torch.Tensor:
         """An extension of `torch.Tensor.to()` to MetaTensor
@@ -101,14 +113,13 @@ class MetaTensor(torch.Tensor):
             MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
         """
         # this imitates c++ function in the way of @overload
-        return super().to(*args, **kwargs)
-
-    @to.register
-    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
-
-    @to.register
-    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
-        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
-        return MetaTensor(deepcopy(result), fake_device=device)
+        device = None
+        for arg in args:
+            if isinstance(arg, str) or isinstance(arg, _device):
+                device = arg
+        if 'device' in kwargs:
+            device = kwargs['device']
+        result = super().to(*args, **kwargs)
+        if device is not None:
+            result = MetaTensor(deepcopy(result), fake_device=device)
+        return result
diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
index 4dc1cdc2d..ff61e604c 100644
--- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
+++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -13,6 +13,9 @@ from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -74,7 +77,7 @@ def _run_ckpt_solver(rank):
         m = model_cls(num_classes=5)
         graph = tracer.trace(root=m)
         gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-        MetaInfoProp(gm).run(data)
+        MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
         if solver == solver_rotor:
@@ -89,7 +92,6 @@ def _run_ckpt_solver(rank):
 
 
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
-@pytest.mark.skip('TODO: refactor ckpt solvers')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
 
@@ -111,7 +113,7 @@ def _run_ckpt_solver_torch11(rank):
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         if solver == solver_rotor:
-            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+            gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
         else:
             gm = solver(gm)
         assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
@@ -129,5 +131,5 @@ def test_ckpt_solver_torch11():
 
 if __name__ == '__main__':
     _run_ckpt_solver(rank=0)
-    # test_ckpt_solver()
-    # test_ckpt_solver_torch11()
+    test_ckpt_solver()
+    test_ckpt_solver_torch11()
diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py
index 8f5d6abe5..ec30d0e76 100644
--- a/tests/test_fx/test_ckpt_solvers/test_linearize.py
+++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py
@@ -1,3 +1,4 @@
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 import torch
 import torchvision.models as tm
 from colossalai.fx import ColoTracer
@@ -5,6 +6,9 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.algorithms import solver_rotor, linearize
 from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
 import pytest
+from colossalai import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from colossalai.fx.profiler.tensor import MetaTensor
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -15,7 +19,7 @@ except:
     with_codegen = False
 
 
-@pytest.mark.skip(reason='TODO: modify calculations in rotor')
+@pytest.mark.skip(reason='TODO: modify the logger')
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 def test_linearize():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@@ -26,6 +30,7 @@ def test_linearize():
         graph = tracer.trace(model)
         graph.set_codegen(ActivationCheckpointCodeGen())
         gm = ColoGraphModule(model, graph, model.__class__.__name__)
+        MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu'))
        node_list = linearize(gm)
        gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
        op_list = gm.__sequence__.list_operations()
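
For reference, the pieces above fit together roughly as follows: `MetaInfoProp` now records lists of saved tensors (`fwd_in`, `fwd_tmp`, `fwd_out`) on each node, and the checkpoint solvers reduce those lists to byte counts on demand through the new `calculate_fwd_in` / `calculate_fwd_tmp` / `calculate_fwd_out` helpers. The sketch below is illustrative only and is not part of the patch; it mirrors the call signatures used in the updated tests, assumes a torch version for which `META_COMPATIBILITY` holds (the `MetaTensor` path), and the model and input shape are arbitrary choices.

```python
# Minimal usage sketch of the new profiler helpers (illustrative, not part of the patch).
import torch
import torchvision.models as tm

from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from colossalai.fx.profiler.tensor import MetaTensor

model = tm.resnet18()
data = torch.rand(4, 3, 224, 224, device='meta')

# trace the model and annotate every node with fwd_in / fwd_tmp / fwd_out metadata
graph = ColoTracer().trace(model)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(MetaTensor(data, fake_device='cpu'))

# the solvers reduce the per-node saved-tensor lists to byte counts on demand
for n in gm.graph.nodes:
    print(n.name, calculate_fwd_in(n), calculate_fwd_tmp(n), calculate_fwd_out(n))
```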