mirror of https://github.com/hpcaitech/ColossalAI
rename
parent
865f2e0196
commit
d914a21d64
|
@ -111,21 +111,6 @@ class TraceIndice(object):
|
|||
if j not in node_to_compute[i]:
|
||||
node_to_compute[i].append(j)
|
||||
|
||||
def _mark_indice_equal(self, node1, dim1, node2, dim2):
|
||||
"""
|
||||
Mark 2 indice to be equal.
|
||||
|
||||
Args:
|
||||
idx1 (int): indice count.
|
||||
idx2 (int): indice count.
|
||||
"""
|
||||
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
|
||||
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
|
||||
# if node1_idx > node2_idx:
|
||||
# self._add_source(node2, dim2, node1, dim1)
|
||||
# else:
|
||||
# self._add_source(node1, dim1, node2, dim2)
|
||||
|
||||
def _mark_computation(self, node, idx, dim):
|
||||
"""
|
||||
Mark some dims of node as computed.
|
||||
|
@ -273,19 +258,14 @@ class TraceIndice(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
if len(node.args) == 2:
|
||||
input_node, weight = node.args
|
||||
bias = None
|
||||
_, weight = node.args
|
||||
else:
|
||||
input_node, weight, bias = node.args
|
||||
_, weight, _ = node.args
|
||||
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
self._inherit_indice(weight, 1, node, -1)
|
||||
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
self._mark_indice_equal(input_node, -1, weight, 0)
|
||||
|
||||
if bias:
|
||||
self._mark_indice_equal(input_node, -1, bias, 0)
|
||||
|
||||
def _assign_matmul_indice(self, node, node_idx):
|
||||
"""
|
||||
|
@ -306,7 +286,6 @@ class TraceIndice(object):
|
|||
|
||||
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
self._mark_indice_equal(matmul_left, -1, matmul_right, -2)
|
||||
|
||||
def _assign_layernorm_indice(self, node, idx):
|
||||
"""
|
||||
|
@ -338,12 +317,6 @@ class TraceIndice(object):
|
|||
nodes_in.append(node_in)
|
||||
self._mark_computation_from_node(node_in, node)
|
||||
assert len(nodes_in) <= 2
|
||||
if len(nodes_in) == 2:
|
||||
node_in0_shape = get_node_shape(nodes_in[0])
|
||||
node_in1_shape = get_node_shape(nodes_in[1])
|
||||
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1):
|
||||
if node_in0_shape[i] == node_in1_shape[i]:
|
||||
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
|
||||
|
||||
def _assgin_no_change_indice(self, node, idx):
|
||||
self._assign_indice_as_input(node, idx)
|
||||
|
|
Loading…
Reference in New Issue