|
|
|
@ -1,13 +1,34 @@
|
|
|
|
|
import copy |
|
|
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
|
|
|
|
from .utils import ( |
|
|
|
|
find_idx_by_name, |
|
|
|
|
get_node_shape, |
|
|
|
|
) |
|
|
|
|
from torch.fx.node import Node |
|
|
|
|
|
|
|
|
|
from .utils import find_idx_by_name, get_node_shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TraceIndice(object): |
|
|
|
|
def __init__(self, node_list) -> None: |
|
|
|
|
""" |
|
|
|
|
Trace all indice infomation for every node. |
|
|
|
|
|
|
|
|
|
Indice is a logical concept. Equal dims can been treated as one indice. |
|
|
|
|
eg. dim(x1) = [a, b, c] |
|
|
|
|
dim(x2) = [d, e, f] |
|
|
|
|
and we have x3 = x1 * x2. |
|
|
|
|
then a=d, b=e, c=f, due to the broadcast property, |
|
|
|
|
dim(x1)=dim(x2)=dim(x3)=[a, b, c] |
|
|
|
|
This class will record every node's dims' indice, compute and source. |
|
|
|
|
|
|
|
|
|
Attibutes: |
|
|
|
|
node_list (List) |
|
|
|
|
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}] |
|
|
|
|
indice_view_list (Dict): not used for now |
|
|
|
|
indice_count (int): record indice number |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
node_list (List) |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, node_list: List) -> None: |
|
|
|
|
self.node_list = node_list |
|
|
|
|
self.indice_trace_list = self._init_indice_trace_list() |
|
|
|
|
self.indice_view_list = {} |
|
|
|
|