|
|
|
@ -1,13 +1,34 @@
|
|
|
|
|
import copy
|
|
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
|
|
from .utils import (
|
|
|
|
|
find_idx_by_name,
|
|
|
|
|
get_node_shape,
|
|
|
|
|
)
|
|
|
|
|
from torch.fx.node import Node
|
|
|
|
|
|
|
|
|
|
from .utils import find_idx_by_name, get_node_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TraceIndice(object):
|
|
|
|
|
def __init__(self, node_list) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Trace all indice infomation for every node.
|
|
|
|
|
|
|
|
|
|
Indice is a logical concept. Equal dims can been treated as one indice.
|
|
|
|
|
eg. dim(x1) = [a, b, c]
|
|
|
|
|
dim(x2) = [d, e, f]
|
|
|
|
|
and we have x3 = x1 * x2.
|
|
|
|
|
then a=d, b=e, c=f, due to the broadcast property,
|
|
|
|
|
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
|
|
|
|
|
This class will record every node's dims' indice, compute and source.
|
|
|
|
|
|
|
|
|
|
Attibutes:
|
|
|
|
|
node_list (List)
|
|
|
|
|
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
|
|
|
|
|
indice_view_list (Dict): not used for now
|
|
|
|
|
indice_count (int): record indice number
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
node_list (List)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, node_list: List) -> None:
|
|
|
|
|
self.node_list = node_list
|
|
|
|
|
self.indice_trace_list = self._init_indice_trace_list()
|
|
|
|
|
self.indice_view_list = {}
|
|
|
|
|