# mirror of https://github.com/hpcaitech/ColossalAI
from typing import Dict, List, Tuple, Union

import torch
from torch.fx import GraphModule, Node

from .._compatibility import compatibility, is_compatible_with_meta

if is_compatible_with_meta():
    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

__all__ = [
    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]


@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate the activation size of a node.

    Args:
        out (Union[torch.Tensor, Dict, List, Tuple, int]): the output of a `torch.nn.Module` or `torch.nn.functional` call

    Returns:
        int: the activation size in bytes
    """
    act_size = 0
    if isinstance(out, torch.Tensor):
        act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
    elif isinstance(out, dict):
        act_size += activation_size(list(out.values()))
    elif isinstance(out, (tuple, list, set)):
        for element in out:
            act_size += activation_size(element)
    return act_size
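

# Illustrative usage (a float32 element occupies 4 bytes):
#     >>> activation_size(torch.zeros(4, 8))                            # 4 * 8 * 4
#     128
#     >>> activation_size({'a': torch.zeros(2), 'b': [torch.zeros(3)]})
#     20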


@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
    """Calculate the total parameter size of a `torch.nn.Module`.

    Args:
        mod (torch.nn.Module): the target `torch.nn.Module`

    Returns:
        int: the parameter size in bytes
    """
    param_size = 0
    for param in mod.parameters():
        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
    return param_size
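

# Illustrative usage (default float32 parameters, 4 bytes each):
#     >>> parameter_size(torch.nn.Linear(4, 2))    # (4 * 2 + 2) params * 4 bytes
#     40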


def calculate_fwd_in(n: Node) -> int:
    """A helper function to calculate `fwd_in`.

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_in (int): the `fwd_in` size in bytes
    """
    return activation_size(n.meta["fwd_in"])
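

# A hedged usage sketch: the `calculate_*` helpers assume the graph has already
# been annotated by a profiling pass that fills `n.meta["fwd_in"]`,
# `n.meta["fwd_tmp"]` and `n.meta["fwd_out"]` with the tensors saved at each
# node; the pass name below is an assumption for illustration only.
#
#     gm = torch.fx.symbolic_trace(model)
#     MetaInfoProp(gm).run(sample_input)    # hypothetical profiling pass
#     fwd_in_total = sum(calculate_fwd_in(n) for n in gm.graph.nodes)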


def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp`.

    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_tmp (int): the `fwd_tmp` size in bytes
    """

    def is_relu_like_node(n: Node) -> bool:
        """Check if a node is a ReLU-like node.

        ReLU-like nodes have the following properties:
        - They are either `call_function` or `call_module`
        - Their output tensors are directly saved for backward
        - Their input tensors are not saved for backward

        An example is `torch.nn.functional.softmax`, whose joint forward + backward graph looks like:

            def forward(self, input_2):
                _softmax_default = torch.ops.aten._softmax.default(input_2, None, None);  input_2 = None
                zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
                detach_default = torch.ops.aten.detach.default(_softmax_default);  _softmax_default = None
                _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None);  zeros_like_default = detach_default = None
                detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default);  _softmax_backward_data_default = None
                detach_default_2 = torch.ops.aten.detach.default(detach_default_1);  detach_default_1 = None

        Args:
            n (Node): a node from the graph

        Returns:
            bool: whether the node is a ReLU-like node
        """
        if n.op == 'call_function':
            return n.target in OUTPUT_SAVED_OPS
        elif n.op == 'call_module':
            return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
        return False

    if not is_relu_like_node(n):
        return activation_size(n.meta["fwd_tmp"])
    return 0
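

# Hedged example: assuming `OUTPUT_SAVED_OPS`/`OUTPUT_SAVED_MOD` list ops such as
# `torch.ops.aten.relu.default` and `torch.nn.ReLU` (an assumption based on the
# patch described above), a traced ReLU node saves its output rather than its
# input for backward, so this helper reports 0 bytes of `fwd_tmp` for it.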


def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out`.

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_out (int): the `fwd_out` size in bytes
    """

    def intersect(a, b):
        return {k: a[k] for k in a if k in b}

    fwd_in = dict()
    for u in n.users:
        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
    return activation_size(intersect(fwd_in, fwd_out))
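

# Hedged sketch of the intersection above, with hypothetical uuid keys:
#     a = {'u1': t1, 'u2': t2}    # tensors consumed by n's users (fwd_in)
#     b = {'u2': t2, 'u3': t3}    # tensors produced by n (fwd_out)
#     intersect(a, b)             # -> {'u2': t2}: only outputs of n that its
#                                 #    users actually keep are counted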


def is_inplace(n: Node) -> bool:
    """Get the inplace argument from a torch.fx.Node.

    Args:
        n (Node): a node from the graph

    Returns:
        bool: whether this op is in-place
    """
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
        if is_compatible_with_meta():
            from .constants import ALIAS_ATEN
            if n.target in ALIAS_ATEN:
                inplace = True
    elif n.op == "call_module":
        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)

    return inplace
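

# Illustrative cases:
#     * a `call_module` node wrapping `torch.nn.ReLU(inplace=True)` returns True
#       via the submodule's `inplace` attribute;
#     * a `call_function` node for `torch.nn.functional.relu(x, inplace=True)`
#       returns True via the `inplace` kwarg.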