Browse Source

rename

pull/2364/head
oahzxl 2 years ago
parent
commit
d914a21d64
  1. 31
      colossalai/autochunk/trace_indice.py

31
colossalai/autochunk/trace_indice.py

@ -111,21 +111,6 @@ class TraceIndice(object):
if j not in node_to_compute[i]:
node_to_compute[i].append(j)
def _mark_indice_equal(self, node1, dim1, node2, dim2):
"""
Mark 2 indice to be equal.
Args:
idx1 (int): indice count.
idx2 (int): indice count.
"""
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
# if node1_idx > node2_idx:
# self._add_source(node2, dim2, node1, dim1)
# else:
# self._add_source(node1, dim1, node2, dim2)
def _mark_computation(self, node, idx, dim):
"""
Mark some dims of node as computed.
@ -273,19 +258,14 @@ class TraceIndice(object):
node_idx (int)
"""
if len(node.args) == 2:
input_node, weight = node.args
bias = None
_, weight = node.args
else:
input_node, weight, bias = node.args
_, weight, _ = node.args
self._assign_indice_as_input(node, node_idx)
self._inherit_indice(weight, 1, node, -1)
self._mark_computation(node, node_idx, [-1])
self._mark_indice_equal(input_node, -1, weight, 0)
if bias:
self._mark_indice_equal(input_node, -1, bias, 0)
def _assign_matmul_indice(self, node, node_idx):
"""
@ -306,7 +286,6 @@ class TraceIndice(object):
self._mark_computation_from_node(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1])
self._mark_indice_equal(matmul_left, -1, matmul_right, -2)
def _assign_layernorm_indice(self, node, idx):
"""
@ -338,12 +317,6 @@ class TraceIndice(object):
nodes_in.append(node_in)
self._mark_computation_from_node(node_in, node)
assert len(nodes_in) <= 2
if len(nodes_in) == 2:
node_in0_shape = get_node_shape(nodes_in[0])
node_in1_shape = get_node_shape(nodes_in[1])
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1):
if node_in0_shape[i] == node_in1_shape[i]:
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
def _assgin_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)

Loading…
Cancel
Save