mirror of https://github.com/hpcaitech/ColossalAI
add doc for trace indice
parent
0b6af554df
commit
1be0ac3cbf
|
@ -1,13 +1,34 @@
|
||||||
import copy
|
import copy
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from .utils import (
|
from torch.fx.node import Node
|
||||||
find_idx_by_name,
|
|
||||||
get_node_shape,
|
from .utils import find_idx_by_name, get_node_shape
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TraceIndice(object):
|
class TraceIndice(object):
|
||||||
def __init__(self, node_list) -> None:
|
"""
|
||||||
|
Trace all indice infomation for every node.
|
||||||
|
|
||||||
|
Indice is a logical concept. Equal dims can been treated as one indice.
|
||||||
|
eg. dim(x1) = [a, b, c]
|
||||||
|
dim(x2) = [d, e, f]
|
||||||
|
and we have x3 = x1 * x2.
|
||||||
|
then a=d, b=e, c=f, due to the broadcast property,
|
||||||
|
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
|
||||||
|
This class will record every node's dims' indice, compute and source.
|
||||||
|
|
||||||
|
Attibutes:
|
||||||
|
node_list (List)
|
||||||
|
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
|
||||||
|
indice_view_list (Dict): not used for now
|
||||||
|
indice_count (int): record indice number
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_list (List)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node_list: List) -> None:
|
||||||
self.node_list = node_list
|
self.node_list = node_list
|
||||||
self.indice_trace_list = self._init_indice_trace_list()
|
self.indice_trace_list = self._init_indice_trace_list()
|
||||||
self.indice_view_list = {}
|
self.indice_view_list = {}
|
||||||
|
|
Loading…
Reference in New Issue