mirror of https://github.com/hpcaitech/ColossalAI
add doc for trace indice
parent
0b6af554df
commit
1be0ac3cbf
|
@ -1,13 +1,34 @@
|
|||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from .utils import (
|
||||
find_idx_by_name,
|
||||
get_node_shape,
|
||||
)
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import find_idx_by_name, get_node_shape
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
def __init__(self, node_list) -> None:
|
||||
"""
|
||||
Trace all indice infomation for every node.
|
||||
|
||||
Indice is a logical concept. Equal dims can been treated as one indice.
|
||||
eg. dim(x1) = [a, b, c]
|
||||
dim(x2) = [d, e, f]
|
||||
and we have x3 = x1 * x2.
|
||||
then a=d, b=e, c=f, due to the broadcast property,
|
||||
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
|
||||
This class will record every node's dims' indice, compute and source.
|
||||
|
||||
Attibutes:
|
||||
node_list (List)
|
||||
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
|
||||
indice_view_list (Dict): not used for now
|
||||
indice_count (int): record indice number
|
||||
|
||||
Args:
|
||||
node_list (List)
|
||||
"""
|
||||
|
||||
def __init__(self, node_list: List) -> None:
|
||||
self.node_list = node_list
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
|
|
Loading…
Reference in New Issue