[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
Super Daniel 2022-10-26 14:24:41 +08:00 committed by GitHub
parent 63f250bbd4
commit 0584654c79
14 changed files with 177 additions and 122 deletions
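
The heart of the change is a module split inside `colossalai.fx.profiler`: the generic size helpers move from `memory.py` to `memory_utils.py`, and the sharding-aware per-node helpers (now including time estimates) land in a new `shard_utils.py`. A minimal sketch of the import paths before and after, inferred from the diffs below rather than from the repository itself:

    # Before this commit (removed path):
    # from colossalai.fx.profiler.memory import activation_size, calculate_fwd_in

    # After: plain size utilities live in memory_utils ...
    from colossalai.fx.profiler.memory_utils import activation_size, is_inplace, parameter_size

    # ... and sharding-aware per-node helpers live in shard_utils.
    from colossalai.fx.profiler.shard_utils import (
        calculate_bwd_time,
        calculate_fwd_in,
        calculate_fwd_out,
        calculate_fwd_time,
        calculate_fwd_tmp,
    )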


@@ -1,7 +1,9 @@
+import math
 from typing import List, Set, Tuple
 import torch
 from torch.fx import GraphModule, Node
-import math
 from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
 
 __all__ = ['chen_greedy']


@@ -1,15 +1,17 @@
+import math
 import sys
 from typing import List, Tuple
-from colossalai.fx.profiler.memory import calculate_fwd_in
 from torch.fx import Node
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
-import math
-from .linearize import linearize
-from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
 from colossalai.logging import get_dist_logger
+from .linearize import linearize
+from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
 
 # global variable to indicate whether the solver has failed
 SOLVER_FAILED = False
@@ -18,7 +20,7 @@ SOLVER_FAILED = False
 # https://gitlab.inria.fr/hiepacs/rotor
 # paper link: https://hal.inria.fr/hal-02352969
 def _compute_table(chain: Chain, mmax) -> Tuple:
     """Returns the optimal table: a tuple containing:
     Opt[m][lmin][lmax] with lmin = 0...chain.length
     and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
     what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
     """Get the forward xbar of a node
 
     Args:
         node (List[Node]): List of torch.fx Node,
             indicates a node in linearized graph
 
     Returns:
@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
     # build module if module not found
     except ModuleNotFoundError:
-        import subprocess
         import os
+        import subprocess
         logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
         this_dir = os.path.dirname(os.path.abspath(__file__))
         result = subprocess.Popen(


@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
 import torch
 import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
 from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_flatten
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
 
 @compatibility(is_backward_compatible=True)
 class ConcreteInfoProp(torch.fx.Interpreter):
@@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
         DIM_HIDDEN = 16
         DIM_OUT = 16
         model = torch.nn.Sequential(
             torch.nn.Linear(DIM_IN, DIM_HIDDEN),
             torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
         ).cuda()
         input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
         gm = symbolic_trace(model)
         interp = ConcreteInfoProp(gm)
         interp.run(input_sample)
         print(interp.summary(unit='kb'))
 
         output of above code is
         Op type      Op       Forward time    Backward time    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
         -----------  -------  --------------  ---------------  -------------  ---------  ---------  ---------  ---------
         placeholder  input_1  0.0 s           0.0 s            False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
@@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):
     def summary(self, unit: str = 'MB') -> str:
         """
         Summarizes the memory and FLOPs statistics of the `GraphModule` in
         tabular format. Note that this API requires the ``tabulate`` module
         to be installed.
         """
         # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py


@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
 import torch
 import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
-                                    profile_function, profile_method, profile_module)
 from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_map
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (
+    GraphInfo,
+    activation_size,
+    calculate_fwd_in,
+    calculate_fwd_out,
+    calculate_fwd_tmp,
+    profile_function,
+    profile_method,
+    profile_module,
+)
 
 @compatibility(is_backward_compatible=True)
 class TensorMetadata(NamedTuple):
@@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         DIM_HIDDEN = 16
         DIM_OUT = 16
         model = torch.nn.Sequential(
             torch.nn.Linear(DIM_IN, DIM_HIDDEN),
             torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
         )
         input_sample = torch.rand(BATCH_SIZE, DIM_IN)
@@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
         interp = MetaInfoProp(gm)
         interp.run(input_sample)
         print(interp.summary(format='kb'))    # don't panic if some statistics are 0.00 MB
 
         # output of above code is
         Op type      Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
         -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
         placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
@@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):
     def summary(self, unit: str = 'MB') -> str:
         """
         Summarizes the memory and FLOPs statistics of the `GraphModule` in
         tabular format. Note that this API requires the ``tabulate`` module
         to be installed.
         """
         # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py


@@ -1,12 +1,18 @@
 from .._compatibility import is_compatible_with_meta
 
 if is_compatible_with_meta():
-    from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
     from .opcount import flop_mapping
     from .profiler import profile_function, profile_method, profile_module
+    from .shard_utils import (
+        calculate_bwd_time,
+        calculate_fwd_in,
+        calculate_fwd_out,
+        calculate_fwd_time,
+        calculate_fwd_tmp,
+    )
     from .tensor import MetaTensor
 else:
     from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 
 from .dataflow import GraphInfo
-from .memory import activation_size, is_inplace, parameter_size
+from .memory_utils import activation_size, is_inplace, parameter_size
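
Because this `__init__` re-exports everything, the public surface stays unchanged for callers; only the new time helpers are gated on meta-tensor support. A hedged consumer-side sketch, mirroring the gate above (`is_compatible_with_meta` is itself importable from the package, as the test diff at the end shows):

    from colossalai.fx.profiler import (
        activation_size,
        calculate_fwd_in,
        calculate_fwd_out,
        calculate_fwd_tmp,
        is_compatible_with_meta,
    )

    if is_compatible_with_meta():
        # the time estimators only exist on the meta-compatible branch
        from colossalai.fx.profiler import calculate_bwd_time, calculate_fwd_time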


@@ -6,7 +6,7 @@ from typing import Dict, List
 from torch.fx import Graph, Node
 
 from .._compatibility import compatibility
-from .memory import activation_size, is_inplace
+from .memory_utils import activation_size, is_inplace
 
 class Phase(Enum):
@@ -29,7 +29,7 @@ class GraphInfo:
     placeholders saved for  |   |    \__________      |   |
     backward.               |   |               \     |   |
                             | [fwd_tmp] ------> [bwd_tmp] |    <-----
                             |   |    \_________      |   |    [bwd_tmp] marks the peak memory
                             |   / \          \       |   |    in backward pass.
     [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
     in [fwd_tmp] because    |   |      \_____        |   |
@@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ---->  p           o     <---- We need to keep track of grad out
                        |\________  |
                        |          \|
                        f --------> b
                        |\  \_____  |
                        | \       \ /
                        f  f ----> b     <---- Not every forward result needs to be saved for backward
                        |   \____   |
                        |        \  |
                        f ----> b        <---- Backward can be freed as soon as it is required no more.
                        l
     =============================================================================
     Args:
         graph (Graph): The autograd graph with nodes marked for keyword `phase`.


@@ -1,5 +1,5 @@
-from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
 from .profiler import profile_function, profile_method, profile_module
 from .profiler_function import *
 from .profiler_module import *
 from .registry import meta_profiler_function, meta_profiler_module
+from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp


@@ -5,7 +5,7 @@ import torch
 from torch.fx.node import Argument, Target
 
 from ..._compatibility import compatibility
-from ..memory import activation_size
+from ..memory_utils import activation_size
 from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
 from .registry import meta_profiler_function, meta_profiler_module
@@ -27,7 +27,7 @@ class GraphInfo:
     placeholders saved for  |   |    \__________      |   |
     backward.               |   |               \     |   |
                             | [fwd_tmp] ------> [bwd_tmp] |    <-----
                             |   |    \_________      |   |    [bwd_tmp] marks the peak memory
                             |   / \          \       |   |    in backward pass.
     [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
     in [fwd_tmp] because    |   |   |    \_____       |   |
@@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
 @compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.
 
     Unfortunately, backward memory cost and FLOPs are estimated results.
 
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn.functional` are available.
 
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
@@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable:
 @compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module) -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.
 
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn` are available.
 
     Example:
         >>> input = torch.rand(4, 3, 224, 224, device='meta')
         >>> mod = torch.nn.Conv2d(3, 128, 3)


@@ -0,0 +1,71 @@
+from typing import Dict, List, Tuple, Union
+
+import torch
+from torch.fx import GraphModule, Node
+
+from .._compatibility import compatibility, is_compatible_with_meta
+
+__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+
+
+@compatibility(is_backward_compatible=True)
+def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
+    """Calculate activation size of a node.
+
+    Args:
+        out (Union[torch.Tensor, Dict, List, Tuple, int]): The output of a `torch.nn.Module` or `torch.nn.functional`
+
+    Returns:
+        int: The activation size in bytes
+    """
+    act_size = 0
+    if isinstance(out, torch.Tensor):
+        if out.is_quantized:
+            act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
+        else:
+            act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
+    elif isinstance(out, dict):
+        value_list = [v for _, v in out.items()]
+        act_size += activation_size(value_list)
+    elif isinstance(out, (tuple, list, set)):
+        for element in out:
+            act_size += activation_size(element)
+    return act_size
+
+
+@compatibility(is_backward_compatible=True)
+def parameter_size(mod: torch.nn.Module) -> int:
+    """Calculate parameter size of a node.
+
+    Args:
+        mod (torch.nn.Module): The target `torch.nn.Module`
+
+    Returns:
+        int: The parameter size in bytes
+    """
+    param_size = 0
+    for param in mod.parameters():
+        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
+    return param_size
+
+
+def is_inplace(n: Node):
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        n (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+    inplace = False
+    if n.op == "call_function":
+        inplace = n.kwargs.get("inplace", False)
+        if is_compatible_with_meta():
+            from .constants import ALIAS_ATEN
+            if n.target in ALIAS_ATEN:
+                inplace = True
+    elif n.op == "call_module":
+        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+
+    return inplace
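
The new module is self-contained, so its behavior is easy to sanity-check. A small usage sketch (tensor shapes are arbitrary examples, not taken from the commit):

    import torch

    from colossalai.fx.profiler import activation_size, parameter_size

    out = torch.empty(4, 1024)    # fp32, so 4 bytes per element
    assert activation_size(out) == 4 * 1024 * 4    # numel * element_size
    assert activation_size([out, out]) == 2 * activation_size(out)    # containers recurse

    lin = torch.nn.Linear(1024, 1024)
    assert parameter_size(lin) == (1024 * 1024 + 1024) * 4    # weight + bias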


@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
 from .._compatibility import compatibility
 from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
-from .memory import activation_size, parameter_size
+from .memory_utils import activation_size, parameter_size
 from .opcount import flop_mapping
 from .tensor import MetaTensor
@@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 @compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.
 
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn.functional` are available.
 
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
@@ -342,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
 def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_method` node to
     record the memory cost and FLOPs of the execution.
     """
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -360,13 +360,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
 @compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.
 
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn` are available.
 
     Example:
         >>> input = torch.rand(4, 3, 224, 224, device='meta')
         >>> mod = torch.nn.Conv2d(3, 128, 3)


@@ -1,58 +1,18 @@
-from typing import Dict, List, Tuple, Union
 import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 
 from .._compatibility import compatibility, is_compatible_with_meta
+from .memory_utils import activation_size
 
 if is_compatible_with_meta():
     from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 
-__all__ = [
-    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
-]
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
 
-@compatibility(is_backward_compatible=True)
-def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
-    """Calculate activation size of a node.
-    Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
-    Returns:
-        int: The activation size
-    """
-    act_size = 0
-    if isinstance(out, torch.Tensor):
-        act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
-    elif isinstance(out, dict):
-        value_list = [v for _, v in out.items()]
-        act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
-        for element in out:
-            act_size += activation_size(element)
-    return act_size
-
-@compatibility(is_backward_compatible=True)
-def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate parameter size of a node.
-    Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
-    Returns:
-        int: The parameter size
-    """
-    param_size = 0
-    for param in mod.parameters():
-        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-    return param_size
 
+@compatibility(is_backward_compatible=False)
 def calculate_fwd_in(n: Node) -> int:
-    """A helper function to calculate `fwd_in`
+    """A helper function to calculate `fwd_in` (with sharding spec)
 
     Args:
         n (Node): a node from the graph
@@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
     Returns:
         fwd_in (int): the result of `fwd_in`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec
     return activation_size(n.meta["fwd_in"])
 
+@compatibility(is_backward_compatible=False)
 def calculate_fwd_tmp(n: Node) -> int:
-    """A helper function to calculate `fwd_tmp`
+    """A helper function to calculate `fwd_tmp` (with sharding spec)
     Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
 
     Args:
@@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
         fwd_tmp (int): the result of `fwd_tmp`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec
     def is_relu_like_node(n: Node) -> bool:
         """Check if a node is a ReLU-like node.
         ReLU-like nodes have the following properties:
@@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
     return 0
 
+@compatibility(is_backward_compatible=False)
 def calculate_fwd_out(n: Node) -> int:
-    """A helper function to calculate `fwd_out`
+    """A helper function to calculate `fwd_out` (with sharding spec)
 
     Args:
         n (Node): a node from the graph
@@ -117,6 +81,7 @@ def calculate_fwd_out(n: Node) -> int:
         fwd_out (int): the result of `fwd_out`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec
     def intersect(a, b):
         return {k: a[k] for k in a if k in b}
@@ -127,23 +92,23 @@ def calculate_fwd_out(n: Node) -> int:
     return activation_size(intersect(fwd_in, fwd_out))
 
-def is_inplace(n: Node):
-    """Get the inplace argument from torch.fx.Node
-    Args:
-        node (Node): torch.fx.Node
-    Returns:
-        bool: indicates whether this op is inplace
-    """
-    inplace = False
-    if n.op == "call_function":
-        inplace = n.kwargs.get("inplace", False)
-        if is_compatible_with_meta():
-            from .constants import ALIAS_ATEN
-            if n.target in ALIAS_ATEN:
-                inplace = True
-    elif n.op == "call_module":
-        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
-    return inplace
+def calculate_fwd_time(n: Node) -> float:
+    """A helper function to calculate `fwd_time` (with sharding spec)
+    Args:
+        n (Node): a node from the graph
+    Returns:
+        fwd_time (float): the result of `fwd_time`
+    """
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["fwd_flop"]
+
+
+def calculate_bwd_time(n: Node) -> float:
+    """A helper function to calculate `bwd_time` (with sharding spec)
+    Args:
+        n (Node): a node from the graph
+    Returns:
+        bwd_time (float): the result of `bwd_time`
+    """
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["bwd_flop"]
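
The new `calculate_fwd_time`/`calculate_bwd_time` helpers read the `fwd_flop`/`bwd_flop` entries that `MetaInfoProp` writes into `node.meta`, so they only make sense on an annotated graph. A sketch of the intended flow, modeled on the `MetaInfoProp` docstring example above (the model and shapes are arbitrary):

    import torch
    from torch.fx import symbolic_trace

    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.profiler import (
        calculate_bwd_time,
        calculate_fwd_out,
        calculate_fwd_time,
        calculate_fwd_tmp,
    )

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(torch.rand(8, 16))    # annotates node.meta

    for node in gm.graph.nodes:
        # "time" is currently just the raw FLOP count; see the TODOs above
        print(node.name, calculate_fwd_out(node), calculate_fwd_tmp(node),
              calculate_fwd_time(node), calculate_bwd_time(node))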


@@ -1,7 +1,5 @@
-from colossalai.fx.profiler.memory import activation_size
 import torch
-from torch.fx import Node, Graph
-from torch.fx.graph import _Namespace
+from torch.fx import Graph, Node
 from torch.utils._pytree import tree_map


@@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.fx
 import torchvision.models as tm
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from gpt_utils import gpt2_medium, gpt2_xl
 from torch.fx import symbolic_trace
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor