diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 0d09ed9f0..1e16ab9bd 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -1,13 +1,34 @@ import copy +from typing import Dict, List, Tuple -from .utils import ( - find_idx_by_name, - get_node_shape, -) +from torch.fx.node import Node + +from .utils import find_idx_by_name, get_node_shape class TraceIndice(object): - def __init__(self, node_list) -> None: + """ + Trace all indice information for every node. + + Indice is a logical concept. Equal dims can be treated as one indice. + e.g. dim(x1) = [a, b, c] + dim(x2) = [d, e, f] + and we have x3 = x1 * x2. + then a=d, b=e, c=f, due to the broadcast property, + dim(x1)=dim(x2)=dim(x3)=[a, b, c] + This class will record every node's dims' indice, compute and source. + + Attributes: + node_list (List) + indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}] + indice_view_list (Dict): not used for now + indice_count (int): record indice number + + Args: + node_list (List) + """ + + def __init__(self, node_list: List) -> None: self.node_list = node_list self.indice_trace_list = self._init_indice_trace_list() self.indice_view_list = {}