2022-10-11 03:03:35 +00:00
|
|
|
# For PyTorch 1.11 compatibility, the `typing` generics below are used instead of built-in generic types
|
2022-10-18 02:44:23 +00:00
|
|
|
from typing import Dict, List, Tuple, Union
|
|
|
|
|
2022-10-11 03:03:35 +00:00
|
|
|
import torch
|
2022-10-18 02:44:23 +00:00
|
|
|
from torch.fx import GraphModule, Node
|
|
|
|
|
|
|
|
from ..._compatibility import compatibility
|
2022-10-11 03:03:35 +00:00
|
|
|
|
|
|
|
# Public API of this module: the three per-node forward-memory accessors.
__all__ = [
    "calculate_fwd_in",
    "calculate_fwd_tmp",
    "calculate_fwd_out",
]
|
|
|
|
|
|
|
|
|
2022-10-18 02:44:23 +00:00
|
|
|
@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool:
    """Report whether the forward input of a node is flagged to be saved.

    The flag is not computed here; it is read back from the node's
    metadata, where an earlier pass is expected to have recorded it.

    Args:
        n (Node): a node from the graph.

    Returns:
        bool: the value stored under ``n.meta["save_fwd_in"]``.

    Raises:
        KeyError: if ``n.meta`` has no ``"save_fwd_in"`` entry.
    """
    node_meta = n.meta
    save_flag = node_meta["save_fwd_in"]
    return save_flag
|
|
|
|
|
|
|
|
|
2022-10-18 02:44:23 +00:00
|
|
|
@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int:
    """Fetch the ``fwd_tmp`` quantity recorded for a node.

    The value is read straight from the node's metadata; nothing is
    recomputed. Judging by the key name, it is presumably the temporary
    forward memory. The unit is not established here, so confirm it
    against the pass that writes it.

    Args:
        n (Node): a node from the graph.

    Returns:
        int: the value stored under ``n.meta["fwd_mem_tmp"]``.

    Raises:
        KeyError: if ``n.meta`` has no ``"fwd_mem_tmp"`` entry.
    """
    node_meta = n.meta
    tmp_value = node_meta["fwd_mem_tmp"]
    return tmp_value
|
|
|
|
|
|
|
|
|
2022-10-18 02:44:23 +00:00
|
|
|
@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int:
    """Fetch the ``fwd_out`` quantity recorded for a node.

    The value is read straight from the node's metadata; nothing is
    recomputed. Judging by the key name, it is presumably the forward
    output memory. The unit is not established here, so confirm it
    against the pass that writes it.

    Args:
        n (Node): a node from the graph.

    Returns:
        int: the value stored under ``n.meta["fwd_mem_out"]``.

    Raises:
        KeyError: if ``n.meta`` has no ``"fwd_mem_out"`` entry.
    """
    node_meta = n.meta
    out_value = node_meta["fwd_mem_out"]
    return out_value
|