mirror of https://github.com/hpcaitech/ColossalAI
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py.
* [fx] add shard utils.
* [fx] fix import.
* [fx] check code style.
* [fx] add comment.
* [autoparallel] first move.
* [fx] add time computations.

Branch: pull/1766/head
parent 63f250bbd4
commit 0584654c79
@@ -1,7 +1,9 @@
+import math
from typing import List, Set, Tuple

import torch
from torch.fx import GraphModule, Node
-import math

from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

__all__ = ['chen_greedy']
@@ -1,15 +1,17 @@
import math
import sys
from typing import List, Tuple
-from colossalai.fx.profiler.memory import calculate_fwd_in

from torch.fx import Node
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
-import math
-from .linearize import linearize
-from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function

from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger

+from .linearize import linearize
+from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence

# global vairable to indicate whether the solver is failed
SOLVER_FAILED = False

@@ -18,7 +20,7 @@ SOLVER_FAILED = False
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
    """Returns the optimal table: a tuple containing:
    Opt[m][lmin][lmax] with lmin = 0...chain.length
    and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
    what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint

@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
    """Get the forward xbar of a node

    Args:
        node (List[Node]): List of torch.fx Node,
            indicates a node in linearized graph

    Returns:

@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,

    # build module if module not found
    except ModuleNotFoundError:
-        import subprocess
        import os
+        import subprocess
        logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
        this_dir = os.path.dirname(os.path.abspath(__file__))
        result = subprocess.Popen(
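Editor's note: the last hunk above shows the rotor solver falling back to compiling its C extension (`dynamic_programs_C_version`) on the fly when the import fails. The actual build command is cut off in this diff, so the sketch below only illustrates the pattern; the script name `build_c_ext.py` and the Popen arguments are placeholders, not the repository's real invocation.

import os
import subprocess


def _load_dynamic_programs():
    try:
        import dynamic_programs_C_version as dp
    except ModuleNotFoundError:
        # Compile in-place next to this file, then retry the import.
        # Hypothetical build script/arguments -- the real ones are truncated above.
        this_dir = os.path.dirname(os.path.abspath(__file__))
        result = subprocess.Popen(
            ["python", os.path.join(this_dir, "build_c_ext.py"), "build_ext", "--inplace"],
            cwd=this_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        result.wait()
        import dynamic_programs_C_version as dp
    return dp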
@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple

import torch
import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten

+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module


@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):

@@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
        DIM_HIDDEN = 16
        DIM_OUT = 16
        model = torch.nn.Sequential(
            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
            torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
        ).cuda()
        input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
        gm = symbolic_trace(model)
        interp = ConcreteInfoProp(gm)
        interp.run(input_sample)
        print(interp.summary(unit='kb'))

        output of above code is
        Op type        Op         Forward time    Backward time    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
        -----------    -------    ------------    -------------    -----------    -------    -------    -------    -------
        placeholder    input_1    0.0 s           0.0 s            False          0.00 KB    0.00 KB    0.00 KB    0.00 KB

@@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):

    def summary(self, unit: str = 'MB') -> str:
        """
        Summarizes the memory and FLOPs statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple

import torch
import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
-                                    profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map

+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (
+    GraphInfo,
+    activation_size,
+    calculate_fwd_in,
+    calculate_fwd_out,
+    calculate_fwd_tmp,
+    profile_function,
+    profile_method,
+    profile_module,
+)


@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):

@@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
        DIM_HIDDEN = 16
        DIM_OUT = 16
        model = torch.nn.Sequential(
            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
            torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
        )
        input_sample = torch.rand(BATCH_SIZE, DIM_IN)

@@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
        interp = MetaInfoProp(gm)
        interp.run(input_sample)
        print(interp.summary(format='kb'))    # don't panic if some statistics are 0.00 MB

        # output of above code is
        Op type        Op         Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
        -----------    -------    -------------    --------------    -------    -------    -------    -------
        placeholder    input_1    0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB

@@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):

    def summary(self, unit: str = 'MB') -> str:
        """
        Summarizes the memory and FLOPs statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
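Editor's note: the docstring example above, repackaged as a self-contained script for convenience. Treat it as a sketch: whether MetaInfoProp wants a plain CPU sample (as the docstring shows) or a MetaTensor, and whether the summary keyword is `unit` (per the signature) or `format` (per the docstring), depends on the installed torch/ColossalAI combination.

import torch
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import MetaInfoProp

BATCH_SIZE, DIM_IN, DIM_HIDDEN, DIM_OUT = 2, 4, 16, 16
model = torch.nn.Sequential(
    torch.nn.Linear(DIM_IN, DIM_HIDDEN),
    torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
gm = symbolic_trace(model)                    # torch.fx GraphModule
interp = MetaInfoProp(gm)
interp.run(torch.rand(BATCH_SIZE, DIM_IN))    # propagates shape/memory/FLOP metadata onto gm's nodes
print(interp.summary(unit='kb'))              # needs `tabulate`; prints the table shown in the docstring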
@@ -1,12 +1,18 @@
from .._compatibility import is_compatible_with_meta

if is_compatible_with_meta():
-    from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
    from .opcount import flop_mapping
    from .profiler import profile_function, profile_method, profile_module
+    from .shard_utils import (
+        calculate_bwd_time,
+        calculate_fwd_in,
+        calculate_fwd_out,
+        calculate_fwd_time,
+        calculate_fwd_tmp,
+    )
    from .tensor import MetaTensor
else:
    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

from .dataflow import GraphInfo
-from .memory import activation_size, is_inplace, parameter_size
+from .memory_utils import activation_size, is_inplace, parameter_size
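Editor's note: the net effect of this __init__ is that callers keep importing the helpers from the package rather than from memory_utils / shard_utils directly, so the meta vs. non-meta split stays an internal detail. A minimal sketch of the consumer side, using only names re-exported above:

from colossalai.fx.profiler import (
    activation_size,
    calculate_fwd_in,
    calculate_fwd_out,
    calculate_fwd_tmp,
    is_compatible_with_meta,
    is_inplace,
    parameter_size,
)

if is_compatible_with_meta():
    # the timing helpers introduced by this commit only exist on the meta-compatible path
    from colossalai.fx.profiler import calculate_bwd_time, calculate_fwd_time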
@@ -6,7 +6,7 @@ from typing import Dict, List
from torch.fx import Graph, Node

from .._compatibility import compatibility
-from .memory import activation_size, is_inplace
+from .memory_utils import activation_size, is_inplace


class Phase(Enum):

@@ -29,7 +29,7 @@ class GraphInfo:
    placeholders saved for | | \__________ | |
    backward. | | \ | |
    | [fwd_tmp] ------> [bwd_tmp] | <-----
    | | \_________ | | [bwd_tmp] marks the peak memory
    | / \ \ | | in backward pass.
    [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
    in [fwd_tmp] because | | \_____ | |

@@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    Nodes should have attribute `out` indicating the output of each node.
    ============================================================================
    Placeholder ----> p o <---- We need to keep track of grad out
    |\________ |
    ↓ ↘|
    f --------> b
    |\ \_____ ↑
    | \ ↘ /
    f f ----> b <---- Not every forward result needs to be saved for backward
    | \____ ↑
    ↘ ↘|
    f ----> b <---- Backward can be freed as soon as it is required no more.
    ↘ ↗
    l
    =============================================================================
    Args:
        graph (Graph): The autograd graph with nodes marked for keyword `phase`.

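Editor's note: the diagram above makes a point that is easy to verify on stock PyTorch, independent of the ColossalAI machinery: autograd only keeps the tensors it actually needs for the backward pass. On a recent PyTorch, the public `torch.autograd.graph.saved_tensors_hooks` context manager exposes exactly what gets saved; the sketch below (not part of this diff) records the shapes of saved tensors for a tiny graph.

import torch

saved_shapes = []

def pack(t):
    saved_shapes.append(tuple(t.shape))   # called for every tensor autograd decides to save
    return t

def unpack(t):
    return t

x = torch.randn(4, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.relu(x)    # relu keeps a tensor for its backward
    loss = y.sum()       # sum only needs shape information, nothing extra is saved
loss.backward()
print(saved_shapes)      # typically a single (4, 8) entry, coming from the relu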
@@ -1,5 +1,5 @@
-from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
+from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
@@ -5,7 +5,7 @@ import torch
from torch.fx.node import Argument, Target

from ..._compatibility import compatibility
-from ..memory import activation_size
+from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module

@@ -27,7 +27,7 @@ class GraphInfo:
    placeholders saved for | | \__________ | |
    backward. | | \ | |
    | [fwd_tmp] ------> [bwd_tmp] | <-----
    | | \_________ | | [bwd_tmp] marks the peak memory
    | / \ \ | | in backward pass.
    [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
    in [fwd_tmp] because | | | \_____ | |

@@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.
    Unfortunately, backward memory cost and FLOPs are estimated results.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu

@@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
@@ -0,0 +1,71 @@
from typing import Dict, List, Tuple, Union

import torch
from torch.fx import GraphModule, Node

from .._compatibility import compatibility, is_compatible_with_meta

__all__ = ['activation_size', 'parameter_size', 'is_inplace']


@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate activation size of a node.

    Args:
        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`

    Returns:
        int: The activation size
    """
    act_size = 0
    if isinstance(out, torch.Tensor):
        if out.is_quantized:
            act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
        else:
            act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
    elif isinstance(out, dict):
        value_list = [v for _, v in out.items()]
        act_size += activation_size(value_list)
    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
        for element in out:
            act_size += activation_size(element)
    return act_size


@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
    """Calculate parameter size of a node.

    Args:
        mod (torch.nn.Module): The target `torch.nn.Module`

    Returns:
        int: The parameter size
    """
    param_size = 0
    for param in mod.parameters():
        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
    return param_size


def is_inplace(n: Node):
    """Get the inplace argument from torch.fx.Node

    Args:
        node (Node): torch.fx.Node

    Returns:
        bool: indicates whether this op is inplace
    """
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
        if is_compatible_with_meta():
            from .constants import ALIAS_ATEN
            if n.target in ALIAS_ATEN:
                inplace = True
    elif n.op == "call_module":
        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)

    return inplace
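Editor's note: a quick sanity check of the helpers in this new module, using the package-level re-exports from the __init__ shown earlier. The printed byte counts follow directly from numel() * element_size().

import torch

from colossalai.fx.profiler import activation_size, parameter_size

lin = torch.nn.Linear(128, 256)
x = torch.randn(32, 128)
out = lin(x)

print(parameter_size(lin))    # (128*256 + 256) * 4 bytes = 132096 for fp32
print(activation_size(out))   # 32 * 256 * 4 bytes = 32768
print(activation_size({"a": out, "b": (x, out)}))   # dicts/tuples/lists/sets are traversed recursively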
@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
-from .memory import activation_size, parameter_size
+from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor

@@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu

@@ -342,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_method` node
    record the memory cost and FLOPs of the execution.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

@@ -360,13 +360,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
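Editor's note: the repeated warning about `device=meta` is the whole trick behind this profiler: meta tensors carry shape and dtype but no storage, so operators can be dispatched for bookkeeping without touching real memory. A stock-PyTorch illustration (not ColossalAI-specific; assumes a recent torch with meta kernels for the op used):

import torch

x = torch.rand(100, 100, 100, 100, device='meta')   # no 400 MB buffer is allocated here
y = torch.nn.functional.relu(x)                      # shape/dtype propagation only, still on 'meta'

print(y.shape, y.dtype, y.device)     # torch.Size([100, 100, 100, 100]) torch.float32 meta
print(y.numel() * y.element_size())   # 400000000 bytes the real output would occupy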
@@ -1,58 +1,18 @@
-from typing import Dict, List, Tuple, Union
-
import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node

from .._compatibility import compatibility, is_compatible_with_meta
+from .memory_utils import activation_size

if is_compatible_with_meta():
    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

-__all__ = [
-    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
-]
-
-
-@compatibility(is_backward_compatible=True)
-def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
-    """Calculate activation size of a node.
-
-    Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
-
-    Returns:
-        int: The activation size
-    """
-    act_size = 0
-    if isinstance(out, torch.Tensor):
-        act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
-    elif isinstance(out, dict):
-        value_list = [v for _, v in out.items()]
-        act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
-        for element in out:
-            act_size += activation_size(element)
-    return act_size
-
-
-@compatibility(is_backward_compatible=True)
-def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate parameter size of a node.
-
-    Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
-
-    Returns:
-        int: The parameter size
-    """
-    param_size = 0
-    for param in mod.parameters():
-        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-    return param_size
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
-    """A helper function to calculate `fwd_in`
+    """A helper function to calculate `fwd_in` (with sharding spec)

    Args:
        n (Node): a node from the graph
@@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
    Returns:
        fwd_in (int): the result of `fwd_in`
    """
+    # TODO(super-dainiu): should divide the memory by sharding spec
    return activation_size(n.meta["fwd_in"])


@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
-    """A helper function to calculate `fwd_tmp`
+    """A helper function to calculate `fwd_tmp` (with sharding spec)
    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.

    Args:

@@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
        fwd_tmp (int): the result of `fwd_tmp`
    """

+    # TODO(super-dainiu): should divide the memory by sharding spec
    def is_relu_like_node(n: Node) -> bool:
        """Check if a node is a ReLU-like node.
        ReLU-like nodes have the following properties:

@@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
        return 0


@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
-    """A helper function to calculate `fwd_out`
+    """A helper function to calculate `fwd_out` (with sharding spec)

    Args:
        n (Node): a node from the graph

@@ -117,6 +81,7 @@ def calculate_fwd_out(n: Node) -> int:
        fwd_out (int): the result of `fwd_out`
    """

+    # TODO(super-dainiu): should divide the memory by sharding spec
    def intersect(a, b):
        return {k: a[k] for k in a if k in b}

@@ -127,23 +92,23 @@ def calculate_fwd_out(n: Node) -> int:
    return activation_size(intersect(fwd_in, fwd_out))


-def is_inplace(n: Node):
-    """Get the inplace argument from torch.fx.Node
-
-    Args:
-        node (Node): torch.fx.Node
-
-    Returns:
-        bool: indicates whether this op is inplace
-    """
-    inplace = False
-    if n.op == "call_function":
-        inplace = n.kwargs.get("inplace", False)
-        if is_compatible_with_meta():
-            from .constants import ALIAS_ATEN
-            if n.target in ALIAS_ATEN:
-                inplace = True
-    elif n.op == "call_module":
-        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
-
-    return inplace
+def calculate_fwd_time(n: Node) -> float:
+    """A helper function to calculate `fwd_time` (with sharding spec)
+    Args:
+        n (Node): a node from the graph
+    Returns:
+        fwd_time (float): the result of `fwd_time`
+    """
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["fwd_flop"]
+
+
+def calculate_bwd_time(n: Node) -> float:
+    """A helper function to calculate `bwd_time` (with sharding spec)
+    Args:
+        n (Node): a node from the graph
+    Returns:
+        bwd_time (float): the result of `bwd_time`
+    """
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["bwd_flop"]
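Editor's note: calculate_fwd_time and calculate_bwd_time currently just surface the per-node FLOP counts stored on the graph (the TODOs note that dividing by device count/TFLOPS is still missing), so a rough per-node report can be assembled as below. This sketch assumes MetaInfoProp fills node.meta with `fwd_flop` / `bwd_flop` keys and accepts a plain sample input, as its docstring suggests; adjust for your torch/ColossalAI versions.

import torch
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_bwd_time, calculate_fwd_time

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
gm = symbolic_trace(model)
MetaInfoProp(gm).run(torch.rand(8, 16))    # annotates gm's nodes with fwd_flop / bwd_flop etc.

for node in gm.graph.nodes:
    if "fwd_flop" in node.meta and "bwd_flop" in node.meta:
        print(node.name, calculate_fwd_time(node), calculate_bwd_time(node))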
@@ -1,7 +1,5 @@
-from colossalai.fx.profiler.memory import activation_size
import torch
-from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
+from torch.fx import Graph, Node
from torch.utils._pytree import tree_map

@@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union
import torch
import torch.fx
import torchvision.models as tm
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
from gpt_utils import gpt2_medium, gpt2_xl
from torch.fx import symbolic_trace

+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor
