import torch
from torch.fx import Node

from .._compatibility import compatibility, is_compatible_with_meta
from .memory_utils import activation_size

if is_compatible_with_meta():
    # The op/module tables consulted by `calculate_fwd_tmp` are only importable
    # when meta-tensor profiling is supported.
    # NOTE(review): presumably the `@compatibility(is_backward_compatible=False)`
    # guard on `calculate_fwd_tmp` keeps it from running (and hitting a NameError)
    # otherwise — confirm in `_compatibility`.
    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

# Every public helper defined in this module is listed, including the two
# time helpers at the bottom, so star-imports do not silently drop them.
__all__ = [
    "calculate_fwd_in",
    "calculate_fwd_tmp",
    "calculate_fwd_out",
    "calculate_fwd_time",
    "calculate_bwd_time",
]


@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
    """A helper function to calculate `fwd_in` (with sharding spec).

    Args:
        n (Node): a node from the graph whose ``meta["fwd_in"]`` has already
            been populated by the profiler.

    Returns:
        fwd_in (int): ``activation_size`` of ``n.meta["fwd_in"]`` (presumably
            in bytes — confirm in ``memory_utils``).
    """
    # TODO(super-dainiu): should divide the memory by sharding spec
    return activation_size(n.meta["fwd_in"])


@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp` (with sharding spec).

    For most nodes this is ``activation_size(n.meta["fwd_tmp"])``. Nodes that
    are "ReLU-like" (see ``is_relu_like_node`` below — e.g. ``torch.nn.ReLU``,
    whose profiled temporaries are otherwise inaccurate) are special-cased to
    report 0.

    Args:
        n (Node): a node from the graph whose ``meta["fwd_tmp"]`` has already
            been populated by the profiler.

    Returns:
        fwd_tmp (int): the temporary forward memory attributed to ``n``, or 0
            for ReLU-like nodes.
    """
    # TODO(super-dainiu): should divide the memory by sharding spec

    def is_relu_like_node(n: Node) -> bool:
        """Check if a node is a ReLU-like node.

        ReLU-like nodes have the following properties:
            - They are either `call_function` or `call_module`
            - Their output tensors are directly saved for backward
            - Their input tensors are not saved for backward

        An example is `torch.nn.functional.softmax` which has (forward + backward):
        def forward(self, input_2):
            _softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
            zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
            detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
            _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
            detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
            detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None

        Membership is decided by the ``OUTPUT_SAVED_OPS`` (functions) and
        ``OUTPUT_SAVED_MOD`` (module types) tables from ``.constants``.

        Args:
            n (Node): A node from the graph

        Returns:
            bool: Whether the node is a ReLU-like node
        """
        if n.op == 'call_function':
            return n.target in OUTPUT_SAVED_OPS
        elif n.op == 'call_module':
            # Resolve the submodule so the check is on its concrete class.
            return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
        return False

    if not is_relu_like_node(n):
        return activation_size(n.meta["fwd_tmp"])
    # ReLU-like nodes contribute no temporary memory here.
    return 0


@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out` (with sharding spec).

    Only the outputs of ``n`` that are consumed as a saved input by at least
    one user are counted. Identity is established by ``Tensor.data_ptr()``:
    a tensor in ``n.meta["fwd_out"]`` counts if some user's
    ``meta["fwd_in"]`` holds a tensor with the same data pointer. Non-tensor
    entries are ignored.

    Args:
        n (Node): a node from the graph; ``n.meta["fwd_out"]`` and every
            user's ``meta["fwd_in"]`` must already be populated.

    Returns:
        fwd_out (int): ``activation_size`` of the matched tensors.
    """
    # TODO(super-dainiu): should divide the memory by sharding spec

    # Tensors saved by any user, keyed by storage pointer (later users win on
    # duplicate pointers, matching successive dict updates).
    users_fwd_in = {
        x.data_ptr(): x
        for u in n.users
        for x in u.meta["fwd_in"]
        if isinstance(x, torch.Tensor)
    }
    # Pointers of this node's own tensor outputs.
    own_out_ptrs = {x.data_ptr() for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}

    # Keep the users' tensor objects whose pointer is one of our outputs.
    saved = {ptr: t for ptr, t in users_fwd_in.items() if ptr in own_out_ptrs}
    return activation_size(saved)


def calculate_fwd_time(n: Node) -> float:
    """A helper function to calculate `fwd_time` (with sharding spec).

    Currently this is not a wall-clock time: it returns the raw forward FLOP
    count recorded in ``n.meta["fwd_flop"]``.

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_time (float): the value of ``n.meta["fwd_flop"]``, used as a proxy
            for forward time.
    """
    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
    return n.meta["fwd_flop"]


def calculate_bwd_time(n: Node) -> float:
    """A helper function to calculate `bwd_time` (with sharding spec).

    Currently this is not a wall-clock time: it returns the raw backward FLOP
    count recorded in ``n.meta["bwd_flop"]``.

    Args:
        n (Node): a node from the graph

    Returns:
        bwd_time (float): the value of ``n.meta["bwd_flop"]``, used as a proxy
            for backward time.
    """
    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
    return n.meta["bwd_flop"]