Browse Source

add doc for trace indice

pull/2364/head
oahzxl 2 years ago
parent
commit
1be0ac3cbf
  1. 31
      colossalai/autochunk/trace_indice.py

31
colossalai/autochunk/trace_indice.py

@ -1,13 +1,34 @@
import copy
from typing import Dict, List, Tuple
from .utils import (
find_idx_by_name,
get_node_shape,
)
from torch.fx.node import Node
from .utils import find_idx_by_name, get_node_shape
class TraceIndice(object):
def __init__(self, node_list) -> None:
"""
Trace all indice infomation for every node.
Indice is a logical concept. Equal dims can been treated as one indice.
eg. dim(x1) = [a, b, c]
dim(x2) = [d, e, f]
and we have x3 = x1 * x2.
then a=d, b=e, c=f, due to the broadcast property,
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
This class will record every node's dims' indice, compute and source.
Attibutes:
node_list (List)
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
indice_view_list (Dict): not used for now
indice_count (int): record indice number
Args:
node_list (List)
"""
def __init__(self, node_list: List) -> None:
self.node_list = node_list
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}

Loading…
Cancel
Save