Browse Source

rename in doc

pull/2364/head
oahzxl 2 years ago
parent
commit
a4ed5b0d0d
  1. 66
      colossalai/autochunk/trace_indice.py

66
colossalai/autochunk/trace_indice.py

@ -33,7 +33,7 @@ class TraceIndice(object):
Update the count and return it. To record the idx number.
Returns:
idx_count: int
indice_count: int
"""
self.indice_count += 1
return self.indice_count
@ -113,11 +113,11 @@ class TraceIndice(object):
def _mark_indice_equal(self, node1, dim1, node2, dim2):
"""
Mark 2 index to be equal.
Mark 2 indice to be equal.
Args:
idx1 (int): index count.
idx2 (int): index count.
idx1 (int): indice count.
idx2 (int): indice count.
"""
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
@ -215,7 +215,7 @@ class TraceIndice(object):
def _assign_all_indice(self, node, node_idx):
"""
Add new index for all node's dims.
Add new indice for all node's dims.
Args:
node (node)
@ -229,7 +229,7 @@ class TraceIndice(object):
def _assign_transpose_indice(self, node, node_idx):
"""
Assign index for transpose op.
Assign indice for transpose op.
1. swap input's dim according to transpose args
2. inherit input's computation
@ -246,7 +246,7 @@ class TraceIndice(object):
def _assign_permute_indice(self, node, node_idx):
"""
Assign index for permute op.
Assign indice for permute op.
1. swap input's dim according to permute args
2. inherit input's computation
@ -263,9 +263,9 @@ class TraceIndice(object):
def _assign_linear_indice(self, node, node_idx):
"""
Assign index for linear op.
1. copy trace from input node and change last index accroding to weight
2. mark equal for input node last index, weight first dim and bias dim.
Assign indice for linear op.
1. copy trace from input node and change last indice accroding to weight
2. mark equal for input node last indice, weight first dim and bias dim.
3. inherit input's computation, mark computation for last dim.
Args:
@ -289,9 +289,9 @@ class TraceIndice(object):
def _assign_matmul_indice(self, node, node_idx):
"""
Assign index for matmul op.
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
2. mark equal for input matmul_left -1 index and matmul_right -2 dim.
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
Args:
@ -310,8 +310,8 @@ class TraceIndice(object):
def _assign_layernorm_indice(self, node, idx):
"""
Assign index for layernorm op.
1. assign index as input node
Assign indice for layernorm op.
1. assign indice as input node
2. inherit computation and mark last 2 dims as computed.
Args:
@ -323,8 +323,8 @@ class TraceIndice(object):
def _assign_elementwise_indice(self, node, idx):
"""
Assign index for element-wise op (eg. relu sigmoid add mul).
1. assign index as input node
Assign indice for element-wise op (eg. relu sigmoid add mul).
1. assign indice as input node
2. inherit computation from all input nodes.
Args:
@ -353,7 +353,7 @@ class TraceIndice(object):
def _assign_einsum_indice(self, node, idx):
"""
Assign index for einsum op.
Assign indice for einsum op.
Args:
node (node)
@ -371,8 +371,6 @@ class TraceIndice(object):
for c in i:
all_index.append(c)
all_index = set(all_index)
free_index = set([i for i in right])
sum_index = all_index - free_index
for right_idx, right_indice in enumerate(right):
for left_idx, left_str in enumerate(left):
@ -382,16 +380,10 @@ class TraceIndice(object):
input_nodes[left_idx], source_idx, node, right_idx
)
# for i in sum_index:
# for left_idx, left_str in enumerate(left):
# if i in left_str:
# self._mark_computation(node, idx, left_str.index(i))
# break
def _assign_softmax_indice(self, node, idx):
"""
Assign index for softmax op.
1. assign index as input node
Assign indice for softmax op.
1. assign indice as input node
2. inherit computation and mark softmax dim as computed.
Args:
@ -403,8 +395,8 @@ class TraceIndice(object):
def _assign_unsqueeze_indice(self, node, node_idx):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
Args:
node (node)
@ -416,8 +408,8 @@ class TraceIndice(object):
def _assign_dropout_indice(self, node, node_idx):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
Args:
node (node)
@ -427,8 +419,8 @@ class TraceIndice(object):
def _assign_ones_like_indice(self, node, node_idx):
"""
Assign index for oneslike op.
1. assign new index for all dim
Assign indice for oneslike op.
1. assign new indice for all dim
Args:
node (node)
@ -438,10 +430,10 @@ class TraceIndice(object):
def _assign_view_reshape_indice(self, node, node_idx):
"""
Assign index for view and reshape op.
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
2. compute the real value of -1 in target shape.
3. determine changed dim, and assgin index for generated dim.
3. determine changed dim, and assgin indice for generated dim.
4. log changed dim and generated dim for restore
5. inherit computation.
6. TODO: look into view list to see whether the view is associated with other,
@ -495,7 +487,7 @@ class TraceIndice(object):
+ "view not implemented"
)
# get new index
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
dim_from.reverse()

Loading…
Cancel
Save