[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
pull/1766/head
Super Daniel 2022-10-26 14:24:41 +08:00 committed by GitHub
parent 63f250bbd4
commit 0584654c79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 177 additions and 122 deletions

View File

@ -1,7 +1,9 @@
import math
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
import math
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
__all__ = ['chen_greedy']

View File

@ -1,15 +1,17 @@
import math
import sys
from typing import List, Tuple
from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
# global vairable to indicate whether the solver is failed
SOLVER_FAILED = False
@ -18,7 +20,7 @@ SOLVER_FAILED = False
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
"""Returns the optimal table: a tuple containing:
"""Returns the optimal table: a tuple containing:
Opt[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
Args:
node (List[Node]): List of torch.fx Node,
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
# build module if module not found
except ModuleNotFoundError:
import subprocess
import os
import subprocess
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(

View File

@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):
@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
gm = symbolic_trace(model)
interp = ConcreteInfoProp(gm)
interp.run(input_sample)
print(interp.summary(unit='kb'))
output of above code is
print(interp.summary(unit='kb'))
output of above code is
Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- ---------
placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py

View File

@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (
GraphInfo,
activation_size,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_tmp,
profile_function,
profile_method,
profile_module,
)
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
interp = MetaInfoProp(gm)
interp.run(input_sample)
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB
# output of above code is
# output of above code is
Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py

View File

@ -1,12 +1,18 @@
from .._compatibility import is_compatible_with_meta
if is_compatible_with_meta():
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .opcount import flop_mapping
from .profiler import profile_function, profile_method, profile_module
from .shard_utils import (
calculate_bwd_time,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo
from .memory import activation_size, is_inplace, parameter_size
from .memory_utils import activation_size, is_inplace, parameter_size

View File

@ -6,7 +6,7 @@ from typing import Dict, List
from torch.fx import Graph, Node
from .._compatibility import compatibility
from .memory import activation_size, is_inplace
from .memory_utils import activation_size, is_inplace
class Phase(Enum):
@ -29,7 +29,7 @@ class GraphInfo:
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | \_____ | |
@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
Nodes should have attribute `out` indicating the output of each node.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
|\________ |
|\________ |
|
f --------> b
|\ \_____
| \ /
f f ----> b <---- Not every forward result needs to be saved for backward
| \____
|
|
f ----> b <---- Backward can be freed as soon as it is required no more.
l
=============================================================================
=============================================================================
Args:
graph (Graph): The autograd graph with nodes marked for keyword `phase`.

View File

@ -1,5 +1,5 @@
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp

View File

@ -5,7 +5,7 @@ import torch
from torch.fx.node import Argument, Target
from ..._compatibility import compatibility
from ..memory import activation_size
from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
@ -27,7 +27,7 @@ class GraphInfo:
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Unfortunately, backward memory cost and FLOPs are estimated results.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)

View File

@ -0,0 +1,71 @@
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from .._compatibility import compatibility, is_compatible_with_meta
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
Returns:
int: The activation size
"""
act_size = 0
if isinstance(out, torch.Tensor):
if out.is_quantized:
act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
else:
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The parameter size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace

View File

@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory import activation_size, parameter_size
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
@ -342,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
record the memory cost and FLOPs of the execution.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@ -360,13 +360,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)

View File

@ -1,58 +1,18 @@
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
from .memory_utils import activation_size
if is_compatible_with_meta():
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
__all__ = [
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
Returns:
int: The activation size
"""
act_size = 0
if isinstance(out, torch.Tensor):
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The parameter size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
"""A helper function to calculate `fwd_in`
"""A helper function to calculate `fwd_in` (with sharding spec)
Args:
n (Node): a node from the graph
@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
Returns:
fwd_in (int): the result of `fwd_in`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
return activation_size(n.meta["fwd_in"])
@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
"""A helper function to calculate `fwd_tmp` (with sharding spec)
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
Args:
@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
fwd_tmp (int): the result of `fwd_tmp`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
def is_relu_like_node(n: Node) -> bool:
"""Check if a node is a ReLU-like node.
ReLU-like nodes have the following properties:
@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
return 0
@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`
"""A helper function to calculate `fwd_out` (with sharding spec)
Args:
n (Node): a node from the graph
@ -117,6 +81,7 @@ def calculate_fwd_out(n: Node) -> int:
fwd_out (int): the result of `fwd_out`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
def intersect(a, b):
return {k: a[k] for k in a if k in b}
@ -127,23 +92,23 @@ def calculate_fwd_out(n: Node) -> int:
return activation_size(intersect(fwd_in, fwd_out))
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
def calculate_fwd_time(n: Node) -> float:
"""A helper function to calculate `fwd_time` (with sharding spec)
Args:
node (Node): torch.fx.Node
n (Node): a node from the graph
Returns:
bool: indicates whether this op is inplace
fwd_time (float): the result of `fwd_time`
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["fwd_flop"]
return inplace
def calculate_bwd_time(n: Node) -> float:
"""A helper function to calculate `bwd_time` (with sharding spec)
Args:
n (Node): a node from the graph
Returns:
bwd_time (float): the result of `bwd_time`
"""
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["bwd_flop"]

View File

@ -1,7 +1,5 @@
from colossalai.fx.profiler.memory import activation_size
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map

View File

@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union
import torch
import torch.fx
import torchvision.models as tm
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from gpt_utils import gpt2_medium, gpt2_xl
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor