diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 1a1e14957..1d2638a02 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -1,10 +1,12 @@
+from dataclasses import asdict
+from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
+from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size


@compatibility(is_backward_compatible=True)
@@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
class MetaInfoProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node with meta tensor and
-    record the shape, FLOPs, MACs and type of the result
+    record the memory usage, FLOPs, and type of the result
    into the corresponding node.

    Usage:
@@ -82,7 +84,7 @@ class MetaInfoProp(torch.fx.Interpreter):
        Returns:
            Any: The result of executing ``n``
        """
-        result, flop_count, mem_stat = super().run_node(n)
+        result, meta_info = super().run_node(n)

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
@@ -90,21 +92,20 @@ class MetaInfoProp(torch.fx.Interpreter):
            else:
                return TensorMetadata(None, None, False, None, 0, False)

-        meta = tree_map(extract_tensor_meta, result)
-        n.meta['tensor_meta'] = meta
+        tensor_meta = tree_map(extract_tensor_meta, result)
+        n.meta['tensor_meta'] = tensor_meta
+        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`

        # TODO: the attribute node_size should be removed in the future
-        setattr(n, 'node_size', mem_stat[1])
-        setattr(n, 'fwd_flop', flop_count[0])
-        setattr(n, 'bwd_flop', flop_count[1])
-        setattr(n, 'fwd_tmp', mem_stat[0])
-        setattr(n, 'fwd_out', mem_stat[1])
-        setattr(n, 'bwd_tmp', mem_stat[2])
-        setattr(n, 'bwd_out', mem_stat[3])
+        setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
+        for par in n.all_input_nodes:
+            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
        n.meta['type'] = type(result)

+        # retain the autograd graph
        for param in self.module.parameters():
            param.grad = None
+
        return result

    # Main Node running APIs
@@ -125,12 +126,9 @@ class MetaInfoProp(torch.fx.Interpreter):

        Returns:
            result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
        """
-        result = super().placeholder(target, args, kwargs)
-        # A placeholder node only has activation
-        return result, (0, 0), (0, activation_size(result), 0, 0)
+        return super().placeholder(target, args, kwargs), GraphInfo()

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -147,10 +145,9 @@ class MetaInfoProp(torch.fx.Interpreter):

        Return:
            result (Any): The argument value that was retrieved
-            flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
-            mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+            meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
""" - return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0) + return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: @@ -166,8 +163,7 @@ class MetaInfoProp(torch.fx.Interpreter): Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ assert not isinstance(target, str) return profile_function(target)(*args, **kwargs) @@ -186,8 +182,7 @@ class MetaInfoProp(torch.fx.Interpreter): Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ return profile_method(target)(*args, **kwargs) @@ -205,8 +200,7 @@ class MetaInfoProp(torch.fx.Interpreter): Return result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ # Retrieve executed args and kwargs values from the environment # Execute the method and return the result @@ -229,10 +223,9 @@ class MetaInfoProp(torch.fx.Interpreter): Return: result (Any): The argument value that was retrieved - flop_count (Tuple): The flop count for (fwd_flop, bwd_flop). - mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out) + meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - return args[0], (0, 0), (0, 0, 0, 0) + return args[0], GraphInfo(fwd_mem_in=activation_size(args[0])) def propagate(self, *args): """ diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 1b46bd494..fb19618b2 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -2,8 +2,9 @@ from ... import META_COMPATIBILITY if META_COMPATIBILITY: from .opcount import flop_mapping from .tensor import MetaTensor - from .profiler import profile_function, profile_method, profile_module, _profile + from .profiler import profile_function, profile_method, profile_module else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module +from .dataflow import GraphInfo from .memory import parameter_size, activation_size diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py new file mode 100644 index 000000000..f6efbf312 --- /dev/null +++ b/colossalai/fx/profiler/dataflow.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict +from torch.fx import Graph, Node +from .memory import activation_size + + +class Stage(Enum): + FORWARD = 0 + LOSS = 1 + BACKWARD = 2 + PLACEHOLDER = 3 + + +@dataclass +class GraphInfo: + """ + GraphInfo is a dataclass for MetaInfo, which measures + the execution memory cost and FLOPs with `MetaTensor`. + The dataflow analysis is conducted on a single node of the FX graph. 
+    ============================================================================
+                             -------------------------------
+                             |            Node             |
+    [fwd_in] are --------->  |   [fwd_in]       [bwd_out]  |  <-----  [bwd_out] marks the memory for `grad_out`
+    placeholders saved for   |      |  \__________    |    |
+    backward.                |      |             \   |    |
+                             |   [fwd_tmp] ----> [bwd_tmp] |  <-----
+                             |      |  \_________    |     |         [bwd_tmp] marks the peak memory
+                             |     / \          \    |     |         in the backward pass.
+    [x] is not counted --->  | [x]  [fwd_tmp] -> [bwd_tmp] |  <-----
+    in [fwd_tmp] because     |  |        |      \_____ |   |
+    it is not saved for      |  |        |            \|   |
+    backward.                -------------------------------
+    ============================================================================
+    Attributes:
+        fwd_flop (int): The forward FLOPs of a certain node.
+        bwd_flop (int): The backward FLOPs of a certain node.
+        fwd_mem_in (int): See the above illustration.
+        fwd_mem_tmp (int): See the above illustration.
+        bwd_mem_tmp (int): See the above illustration.
+        bwd_mem_out (int): See the above illustration.
+    """
+    fwd_flop: int = 0
+    bwd_flop: int = 0
+    fwd_mem_in: int = 0
+    fwd_mem_tmp: int = 0
+    bwd_mem_tmp: int = 0
+    bwd_mem_out: int = 0
+
+
+def is_forward(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.FORWARD
+
+
+def is_loss(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.LOSS
+
+
+def is_placeholder(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.PLACEHOLDER
+
+
+def is_backward(n: Node):
+    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
+    return n.meta['stage'] == Stage.BACKWARD
+
+
+def is_saved(n: Node):
+    return n.meta.get('saved', False)
+
+
+def autograd_graph_analysis(graph: Graph) -> GraphInfo:
+    """Analyze the autograd node dependencies and find out the memory usage.
+    Basically, every node of the input graph should be marked with `Stage.FORWARD`, `Stage.LOSS` or `Stage.BACKWARD` for keyword `stage`.
+    Nodes should have an attribute `out` indicating the output of each node.
+    ============================================================================
+    Placeholder ---->   p           o     <---- We need to keep track of grad out
+                        |\________  |
+                        ↓         ↘ |
+                        f --------> b
+                        |\ \_____   ↑
+                        | \      ↘ /
+                        f  f ----> b     <---- Not every forward result needs to be saved for backward
+                        |   \____  ↑
+                        ↘       ↘  |
+                          f ----> b      <---- Backward can be freed as soon as it is no longer needed.
+                             ↘  ↗
+                              l
+    =============================================================================
+    Args:
+        graph (Graph): The autograd graph with nodes marked `Stage.FORWARD`, `Stage.LOSS` or `Stage.BACKWARD` for keyword `stage`.
+
+    Returns:
+        graph_info (GraphInfo): Meta information for the dataflow.
+    """
+
+    def _peak_memory(deps: Dict[Node, int]):
+        bwd_tmp = 0
+        for k, v in deps.items():
+            if v > 0:
+                bwd_tmp += activation_size(k.meta['out'])
+        return bwd_tmp
+
+    # deps is used to track all the memory dependencies of the graph.
+    deps = {}
+    graph_info = GraphInfo()
+
+    for n in graph.nodes:
+        n: Node
+        if is_saved(n) and not any(map(is_loss, n.users)):
+            # A forward tensor that is marked `saved` but is not
+            # an input to `loss` should be saved during forward.
+            # If the tensor is a placeholder, then it belongs to `fwd_in`.
+            # Any `fwd_in` should be kept in memory even if this function
+            # is checkpointed.
+            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
+            # the node, `fwd_tmp` can be freed.
+            if is_placeholder(n):
+                graph_info.fwd_mem_in += activation_size(n.meta['out'])
+            if is_forward(n):
+                graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+        elif is_backward(n):
+            if len(n.users):
+                # liveness analysis is only used in backward
+                deps[n] = len(n.users)
+                graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
+                for input_n in n.all_input_nodes:
+                    if input_n in deps:
+                        deps[input_n] -= 1
+            else:
+                # basically, a backward node without users is a `grad_out` node
+                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+    return graph_info
diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py
index 46d4add3c..954e8b49b 100644
--- a/colossalai/fx/profiler/experimental/profiler.py
+++ b/colossalai/fx/profiler/experimental/profiler.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
@@ -6,6 +7,44 @@ from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLAC

__all__ = ['profile_function', 'profile_module', 'profile_method']

+
+# this is kept for compatibility
+@dataclass
+class GraphInfo:
+    """
+    GraphInfo is a dataclass for MetaInfo, which measures
+    the execution memory cost and FLOPs with `MetaTensor`.
+    The dataflow analysis is conducted on a single node of the FX graph.
+    ============================================================================
+                             -------------------------------
+                             |            Node             |
+    [fwd_in] are --------->  |   [fwd_in]       [bwd_out]  |  <-----  [bwd_out] marks the memory for `grad_out`
+    placeholders saved for   |      |  \__________    |    |
+    backward.                |      |             \   |    |
+                             |   [fwd_tmp] ----> [bwd_tmp] |  <-----
+                             |      |  \_________    |     |         [bwd_tmp] marks the peak memory
+                             |     / \          \    |     |         in the backward pass.
+    [x] is not counted --->  | [x]  [fwd_tmp] -> [bwd_tmp] |  <-----
+    in [fwd_tmp] because     |  |        |      \_____ |   |
+    it is not saved for      |  |        |            \|   |
+    backward.                -------------------------------
+    ============================================================================
+    Attributes:
+        fwd_flop (int): The forward FLOPs of a certain node.
+        bwd_flop (int): The backward FLOPs of a certain node.
+        fwd_mem_in (int): See the above illustration.
+        fwd_mem_tmp (int): See the above illustration.
+        bwd_mem_tmp (int): See the above illustration.
+        bwd_mem_out (int): See the above illustration.
+    """
+    fwd_flop: int = 0
+    bwd_flop: int = 0
+    fwd_mem_in: int = 0
+    fwd_mem_tmp: int = 0
+    bwd_mem_tmp: int = 0
+    bwd_mem_out: int = 0
+
+
CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
@@ -59,7 +98,7 @@ def profile_function(target: 'Target') -> Callable:
        else:
            profiler = meta_profiler_function.get(target.__name__)
            fwd_flop, _ = profiler(*args, **kwargs)
-        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
+        return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = target.__name__
    func = target
@@ -88,7 +127,7 @@ def profile_method(target: 'Target') -> Callable:
        # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
        fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
        fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
-        return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
+        return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    return f

@@ -118,7 +157,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
        fwd_out = activation_size(out)
        profiler = meta_profiler_module.get(type(module))
        fwd_flop, _ = profiler(module, *args, **kwargs)
-        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
+        return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = module.__class__.__name__
    func = module.forward
diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py
index be5106422..c023d0d1e 100644
--- a/colossalai/fx/profiler/memory.py
+++ b/colossalai/fx/profiler/memory.py
@@ -14,12 +14,10 @@ if META_COMPATIBILITY:

INPLACE_ATEN = [
    aten.add_.Tensor,
-    aten.add.Tensor,
    aten.sub_.Tensor,
    aten.div_.Tensor,
    aten.div_.Scalar,
    aten.mul_.Tensor,
-    aten.mul.Tensor,
    aten.bernoulli_.float,

    # inplace reshaping
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index 8f9fb92e0..347c68c3a 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -1,13 +1,16 @@
+from dataclasses import dataclass
+from enum import auto
from typing import Callable, Any, Dict, Tuple
import torch
-from torch.fx import Graph
+from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
-from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
+from .dataflow import autograd_graph_analysis, Stage
+from .memory import WEIRD_OPS
from .tensor import MetaTensor
from .opcount import flop_mapping

-__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
+__all__ = ['profile_function', 'profile_module', 'profile_method']


def normalize_tuple(x):
@@ -20,8 +23,9 @@ def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


-def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
-    """Profile a Callable function with args and kwargs.
+def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
+    """
+    Profile a Callable function with args and kwargs.

    Args:
        target (Callable): A Callable function
@@ -29,25 +33,32 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
        kwargs (Any): Argument

    Returns:
-        out (Tuple[Any, ...]): The argument value that was retrieved
-        flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
-        mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+        out (Tuple[Any, ...]): The argument value that was retrieved.
+        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
    """
+    # This subgraph traces aten-level ops inside one node.
+    subgraph = Graph()
+
+    # `flop_count` serves as a global dictionary to store results.
    flop_count = {
-        'f': 0,
-        'l': 0,
-        'b': 0,
+        Stage.FORWARD: 0,
+        Stage.LOSS: 0,
+        Stage.BACKWARD: 0,
    }
-    temp = {
-        'f': [],
-        'l': [],
-        'b': [],
-    }
-    stage = 'f'
+
+    # `stage` marks the current autograd stage; it is read from the outer scope.
+    stage = Stage.FORWARD
+
+    # FlopTensor not only collects the FLOP statistics of a single node,
+    # it also builds a full autograd graph for this node.
+    # This makes sure we can analyze the memory dependencies, and
+    # decide which forward intermediate results should be kept until
+    # backward is executed.
+    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):
+        _node: Node
+
        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
@@ -56,66 +67,98 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+
+            def get_node(x):
+                return None if not hasattr(x, '_node') else x._node
+
+            args_node = tree_map(get_node, args)
+            kwargs_node = tree_map(get_node, kwargs)
+            node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+
            def unwrap(x):
+                # if x is an `nn.Parameter`, we can first wrap it with `FlopTensor`
                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                    x = FlopTensor(x.to('meta'))
                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

-            def to_meta(x):
-                return x.to('meta') if isinstance(x, torch.Tensor) else x
-
            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)
            flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
-            if func not in INPLACE_ATEN:
-                temp[stage].append(tree_map(to_meta, normalize_tuple(out)))
+            node.meta['out'] = normalize_tuple(out)
+            node.meta['stage'] = stage

            def wrap(x):
                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x

-            return tree_map(wrap, out)
+            def set_node(x):
+                x._node = node
+
+            out = tree_map(wrap, out)
+            tree_map(set_node, out)
+            return out
+
+    # `WEIRD_OPS` are tough to handle because they don't accept autograd
+    # on meta tensors.
    if target not in WEIRD_OPS:

        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+            return FlopTensor(x.detach().requires_grad_(
+                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
    else:

        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+            return FlopTensor(x.detach().requires_grad_(
+                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x

+    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

-    if isinstance(target, str):
-        # args[0] is the `self` object for this method call
-        self_obj, *args_tail = args
-        out = getattr(self_obj, target)(*args_tail, **kwargs)
-    else:
-        out = target(*args, **kwargs)
+    def set_placeholder(x):
+        if isinstance(x, FlopTensor):
+            x._node = subgraph.create_node('placeholder',
+                                           'placeholder', (subgraph._root,),
+                                           name=subgraph._graph_namespace.create_name('input', x._tensor))
+            x._node.meta['stage'] = Stage.PLACEHOLDER
+            x._node.meta['out'] = (x._tensor,)
+
+    tree_map(set_placeholder, args)
+    tree_map(set_placeholder, kwargs)
+
+    def pack(x):
+        if isinstance(x, FlopTensor):
+            x._node.meta['saved'] = True
+        return x
+
+    def unpack(x):
+        return x
+
+    # mark saved tensors with saved_tensors_hooks
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        if isinstance(target, str):
+            # args[0] is the `self` object for this method call
+            self_obj, *args_tail = args
+            out = getattr(self_obj, target)(*args_tail, **kwargs)
+        else:
+            out = target(*args, **kwargs)
+
+    # If the output is not a floating-point `torch.Tensor` or it does not
+    # require grad, then we should not run backward for this node.
    if is_autogradable(out) and out.requires_grad:
-        stage = 'l'
+        stage = Stage.LOSS
        loss = out.sum()
-        stage = 'b'
+        stage = Stage.BACKWARD
        loss.backward()

-    fwd_flop = flop_count['f']
-    bwd_flop = flop_count['b']
-
-    fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
-    fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
-    bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0
+    graph_info = autograd_graph_analysis(subgraph)
+    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]

    def unwrap(x):
        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

-    return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
+    return tree_map(unwrap, out), graph_info


def profile_function(target: 'Target') -> Callable:
@@ -130,17 +173,15 @@ def profile_function(target: 'Target') -> Callable:
    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
-        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
+        >>> output, meta_info = profile_function(func)(input, inplace=False)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        if kwargs.get('inplace', False):
-            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
-            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
-            out = func(*args, **kwargs)
-            return out, (0, 0), (0, 0, 0, 0)
-        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
-        return out, flop_count, mem_stat
+
+        # If an argument indicates that this `call_function` is inplace, we should
+        # skip the autograd profiling.
+        out, meta = _profile(func, *args, **kwargs)
+        return out, meta

    f.__name__ = target.__name__
    func = target
@@ -156,8 +197,8 @@ def profile_method(target: 'Target') -> Callable:
    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f'{target} instance is not str.'
-        out, flop_count, mem_stat = _profile(target, *args, **kwargs)
-        return out, flop_count, mem_stat
+        out, meta = _profile(target, *args, inplace=False, **kwargs)
+        return out, meta

    return f

@@ -174,17 +215,15 @@ def profile_module(module: torch.nn.Module) -> Callable:
    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
-        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
+        >>> output, meta_info = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        if getattr(module, 'inplace', False):
-            args = tree_map(lambda x: x.to('meta'), args)
-            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
-            out = func(*args, **kwargs)
-            return out, (out.numel(), out.numel()), (0, 0, 0, 0)
-        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
-        return out, flop_count, mem_stat
+
+        # If an argument indicates that this `call_module` is inplace, we should
+        # skip the autograd profiling.
+        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
+        return out, meta

    f.__name__ = module.__class__.__name__
    func = module.forward
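
Usage sketch (not part of the patch). After this change, `MetaInfoProp` copies the `GraphInfo` fields straight into `node.meta`, so downstream passes can read them back as plain dictionary keys. The snippet below is a minimal sketch, assuming a toy module that plain `torch.fx.symbolic_trace` can handle and that `META_COMPATIBILITY` is satisfied; real models may need the Colossal-AI tracer instead.

    import torch
    from torch.fx import symbolic_trace
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp

    # hypothetical toy model; any traceable module is handled the same way
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
    gm = symbolic_trace(model)

    # propagate meta information with a meta-device input
    MetaInfoProp(gm).propagate(torch.rand(8, 64, device='meta'))

    # every node now carries the GraphInfo fields in `node.meta`
    for node in gm.graph.nodes:
        print(node.name,
              node.meta.get('fwd_flop', 0), node.meta.get('bwd_flop', 0),
              node.meta.get('fwd_mem_tmp', 0), node.meta.get('bwd_mem_tmp', 0))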
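
The split between activations that must stay resident (`fwd_mem_in`/`fwd_mem_tmp`) and those that can be discarded hinges on `torch.autograd.graph.saved_tensors_hooks`, which the patch uses to mark every tensor autograd decides to keep for backward. Here is a tiny standalone illustration of that mechanism, independent of this PR (requires PyTorch >= 1.10):

    import torch

    saved_shapes = []

    def pack(x):
        # called once for every tensor autograd saves for the backward pass
        saved_shapes.append(tuple(x.shape))
        return x

    def unpack(x):
        return x

    a = torch.randn(4, 4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = (a * a).sum()
    out.backward()
    print(saved_shapes)    # shapes of the tensors kept alive for backward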