mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
174 lines
6.9 KiB
174 lines
6.9 KiB
from dataclasses import dataclass
|
|
from typing import Any, Callable, Dict, Tuple
|
|
|
|
import torch
|
|
from torch.fx.node import Argument, Target
|
|
|
|
from ..._compatibility import compatibility
|
|
from ..memory_utils import activation_size
|
|
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
|
|
from .registry import meta_profiler_function, meta_profiler_module
|
|
|
|
# Public API of this module.
__all__ = [
    "profile_function",
    "profile_module",
    "profile_method",
]
|
|
|
|
|
|
# This is for compatibility use.
@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
    r"""
    GraphInfo is a dataclass for MetaInfo, which measures
    the execution memory cost and FLOPs with `MetaTensor`.
    The dataflow analysis is conducted on a single node of the FX graph.
    ============================================================================
                            -------------------------------
                            |            Node             |
    [fwd_in] are     ---->  | [fwd_in]       [bwd_out]    |    <----- [bwd_out] marks the memory for `grad_out`.
    placeholders saved for  |     | \__________     |     |
    backward.               |     |            \    |     |
                            | [fwd_tmp] ------> [bwd_tmp] |    <-----
                            |     |  \_________     |     |    [bwd_tmp] marks the peak memory
                            |    / \           \    |     |    in backward pass.
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
    in [fwd_tmp] because    |          |  \_____    |     |
    it is not saved for     |          |        \   |     |
    backward.               -------------------------------
    ============================================================================
    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node.
        bwd_flop (int): The backward FLOPs of a certain node.
        fwd_mem_in (int): See the above illustration.
        fwd_mem_tmp (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """

    # Field order is part of the interface: instances may be constructed
    # positionally, so do not reorder these declarations. All default to 0.
    fwd_flop: int = 0
    bwd_flop: int = 0
    fwd_mem_in: int = 0
    fwd_mem_tmp: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0
|
|
|
|
|
|
CALL_FUNCTION_MSG = """
|
|
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
|
|
from colossalai.fx.profiler.experimental import meta_profiler_function
|
|
@meta_profiler_function.register(YOUR_FUNCTION)
|
|
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
|
|
flops = ...
|
|
macs = ...
|
|
return flops, macs
|
|
"""
|
|
# Assertion message used by the `profile_method` wrapper. Placeholders, in
# order: the method name, the INPLACE_METHOD table, the NON_INPLACE_METHOD table.
CALL_METHOD_MSG = (
    "Please check if {} is an inplace method. "
    "If so, add target to INPLACE_METHOD={}. "
    "Otherwise, add target to NON_INPLACE_METHOD={}"
)
|
|
CALL_MODULE_MSG = """
|
|
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
|
|
from colossalai.fx.profiler.experimental import meta_profiler_module
|
|
@meta_profiler_module.register(YOUR_MODULE)
|
|
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
|
|
flops = ...
|
|
macs = ...
|
|
return flops, macs
|
|
"""
|
|
|
|
|
|
@compatibility(is_backward_compatible=True)
def profile_function(target: "Target") -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.
    Unfortunately, backward memory cost and FLOPs are estimated results.

    Args:
        target: The callable to wrap. A profiler for it must be registered in
            `meta_profiler_function`, keyed either by the callable itself or
            by its `__name__`.

    Returns:
        Callable: A wrapper whose `__name__` equals `target.__name__`. Calling
        it runs `target(*args, **kwargs)` and returns `(output, GraphInfo)`.
        The backward FLOPs are estimated as twice the forward FLOPs.

    Raises:
        AssertionError: When the wrapper is called and no profiler is
            registered for `target`.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, graph_info = profile_function(func)(input, inplace=False)
        >>> fwd_flop, bwd_flop = graph_info.fwd_flop, graph_info.bwd_flop
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_function.has(target) or meta_profiler_function.has(
            target.__name__
        ), CALL_FUNCTION_MSG.format(target)

        fwd_tmp = 0
        fwd_out = 0
        out = target(*args, **kwargs)
        # Ops listed in INPLACE_OPS, or calls made with `inplace=True`, are not
        # counted as producing new output memory.
        if target not in INPLACE_OPS and not kwargs.get("inplace", False):
            fwd_out = activation_size(out)
        # A profiler may be registered under the callable itself or its name.
        if meta_profiler_function.has(target):
            profiler = meta_profiler_function.get(target)
        else:
            profiler = meta_profiler_function.get(target.__name__)
        fwd_flop, _ = profiler(*args, **kwargs)
        # Positional order: fwd_flop, bwd_flop, fwd_mem_in, fwd_mem_tmp,
        # bwd_mem_tmp, bwd_mem_out.
        return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = target.__name__
    return f
|
|
|
|
|
|
@compatibility(is_backward_compatible=True)
def profile_method(target: "Target") -> Callable:
    """
    Wrap a `call_method` node in order to
    record the memory cost and FLOPs of the execution.

    Args:
        target (str): Name of the method to invoke on the first positional
            argument. It must appear in `INPLACE_METHOD` or
            `NON_INPLACE_METHOD`.

    Returns:
        Callable: A wrapper that calls
        `getattr(args[0], target)(*args[1:], **kwargs)` and returns
        `(output, GraphInfo)`. Forward and backward FLOPs are recorded as 0.

    Raises:
        AssertionError: When the wrapper is called with a non-`str` target or
            a method listed in neither method table. The check happens before
            the method is executed.

    Warnings:
        This is not fully implemented and you may follow the error message to debug.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        assert isinstance(target, str), f"{target} instance is not str."
        # Validate before executing, so an unsupported method is never run.
        assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
            target, INPLACE_METHOD, NON_INPLACE_METHOD
        )

        # execute the method and return the result
        out = getattr(self_obj, target)(*args_tail, **kwargs)

        # call_method nodes have no parameters and no FLOPs or MACs (most of
        # them are presumably in-place/view-like — unconfirmed).
        # Methods in INPLACE_METHOD report the output size as fwd_out; all
        # other methods report it as fwd_tmp.
        out_size = activation_size(out)
        if target in INPLACE_METHOD:
            fwd_tmp, fwd_out = 0, out_size
        else:
            fwd_tmp, fwd_out = out_size, 0
        return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    return f
|
|
|
|
|
|
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Args:
        module (torch.nn.Module): The module to wrap. A profiler for
            `type(module)` must be registered in `meta_profiler_module`.

    Returns:
        Callable: A wrapper named after the module's class. Calling it runs
        `module.forward(*args, **kwargs)` and returns `(output, GraphInfo)`.
        The backward FLOPs are estimated as twice the forward FLOPs.

    Raises:
        AssertionError: When the wrapper is called and no profiler is
            registered for `type(module)`.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, graph_info = profile_module(mod)(input)
        >>> fwd_flop, bwd_flop = graph_info.fwd_flop, graph_info.bwd_flop
    """
    # Bound once at wrap time, exactly like the original late-assigned closure.
    forward = module.forward

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))

        fwd_tmp = 0
        fwd_out = 0
        out = forward(*args, **kwargs)
        # Consistent with profile_function: a module configured with
        # `inplace=True` is not counted as producing new output memory; every
        # other module's output is. (The condition was previously inverted,
        # which reported zero output memory for ordinary modules like Conv2d.)
        if not getattr(module, "inplace", False):
            fwd_out = activation_size(out)
        profiler = meta_profiler_module.get(type(module))
        fwd_flop, _ = profiler(module, *args, **kwargs)
        # Positional order: fwd_flop, bwd_flop, fwd_mem_in, fwd_mem_tmp,
        # bwd_mem_tmp, bwd_mem_out.
        return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = module.__class__.__name__
    return f
|