Browse Source

rename

pull/2364/head
oahzxl 2 years ago
parent
commit
1bb1f2ad89
  1. 4
      colossalai/autochunk/search_chunk.py
  2. 50
      colossalai/autochunk/trace_indice.py

4
colossalai/autochunk/search_chunk.py

@ -158,11 +158,11 @@ class SearchChunk(object):
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["idx"]):
for end_dim, _ in enumerate(end_trace["indice"]):
if len(start_traces) > 1:
continue
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["idx"]):
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (
get_node_shape(end_node)[end_dim] == 1

50
colossalai/autochunk/trace_indice.py

@ -19,12 +19,12 @@ class TraceIndice(object):
for n in self.node_list:
if get_node_shape(n) != None:
cur_trace = {
"idx": [None for _ in range(len(get_node_shape(n)))],
"indice": [None for _ in range(len(get_node_shape(n)))],
"compute": [[] for _ in range(len(get_node_shape(n)))],
"source": [{} for _ in range(len(get_node_shape(n)))],
}
else:
cur_trace = {"idx": [], "compute": [], "source": []}
cur_trace = {"indice": [], "compute": [], "source": []}
indice_trace_list.append(cur_trace)
return indice_trace_list
@ -39,12 +39,12 @@ class TraceIndice(object):
return self.indice_count
def _del_dim(self, idx, dim_idx):
self.indice_trace_list[idx]["idx"].pop(dim_idx)
self.indice_trace_list[idx]["indice"].pop(dim_idx)
self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx, dim_idx):
self.indice_trace_list[node_idx]["idx"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
@ -58,7 +58,7 @@ class TraceIndice(object):
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim]
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
node_from_trace["compute"][node_from_dim]
)
@ -181,7 +181,7 @@ class TraceIndice(object):
idx (list): idx of the node
"""
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["idx"]
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node):
"""
@ -195,7 +195,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"]
def _assign_index_as_input(self, node, node_idx, input_node=None):
def _assign_indice_as_input(self, node, node_idx, input_node=None):
"""
Assign node's trace as its input node.
@ -206,10 +206,10 @@ class TraceIndice(object):
if input_node == None:
input_node = node.args[0]
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["idx"]
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.indice_trace_list[node_idx]["idx"] = new_idx_trace
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
self._inherit_all_computation(input_node, node)
@ -225,7 +225,7 @@ class TraceIndice(object):
new_trace = []
for _ in shape:
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["idx"] = new_trace
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node, node_idx):
"""
@ -240,7 +240,7 @@ class TraceIndice(object):
input_node = node.args[0]
tranpose_dim = node.args[1:]
self._assign_index_as_input(node, node_idx, input_node)
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
@ -257,7 +257,7 @@ class TraceIndice(object):
permute_dim = node.args[1:]
input_node = node.args[0]
self._assign_index_as_input(node, node_idx, input_node)
self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
@ -278,7 +278,7 @@ class TraceIndice(object):
else:
input_node, weight, bias = node.args
self._assign_index_as_input(node, node_idx)
self._assign_indice_as_input(node, node_idx)
self._inherit_indice(weight, 1, node, -1)
self._mark_computation(node, node_idx, [-1])
@ -301,7 +301,7 @@ class TraceIndice(object):
matmul_left, matmul_right = node.args
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_index_as_input(node, node_idx, matmul_left)
self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1)
self._mark_computation_from_node(matmul_right, node, [-1, -2])
@ -318,7 +318,7 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, idx)
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1])
def _assign_elementwise_indice(self, node, idx):
@ -331,7 +331,7 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, idx)
self._assign_indice_as_input(node, idx)
nodes_in = []
for node_in in node.args:
if type(node_in) == type(node):
@ -346,7 +346,7 @@ class TraceIndice(object):
self._mark_indice_equal(nodes_in[0], i, nodes_in[1], i)
def _assgin_no_change_indice(self, node, idx):
self._assign_index_as_input(node, idx)
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node)
@ -398,7 +398,7 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, idx)
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node, node_idx):
@ -411,7 +411,7 @@ class TraceIndice(object):
node_idx (int)
"""
self._del_dim(node_idx, -1)
self._assign_index_as_input(node, node_idx)
self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, node.args[1])
def _assign_dropout_indice(self, node, node_idx):
@ -423,7 +423,7 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, node_idx)
self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node, node_idx):
"""
@ -497,7 +497,7 @@ class TraceIndice(object):
# get new index
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_index_as_input(node, node_idx, origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
dim_from.reverse()
for i in dim_from:
self._del_dim(node_idx, i)
@ -516,7 +516,7 @@ class TraceIndice(object):
view_dict = {
"idx_from": [origin_trace[i] for i in dim_from],
"dim_from": dim_from,
"idx_to": [self.indice_trace_list[node_idx]["idx"][i] for i in dim_to],
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
"dim_to": dim_to,
}
self.indice_view_list[node] = view_dict
@ -528,9 +528,9 @@ class TraceIndice(object):
merge_to = min(idx)
merge_from = max(idx)
for trace in self.indice_trace_list:
if merge_from in trace["idx"]:
trace["idx"] = [
merge_to if i == merge_from else i for i in trace["idx"]
if merge_from in trace["indice"]:
trace["indice"] = [
merge_to if i == merge_from else i for i in trace["indice"]
]
def trace_index(self):

Loading…
Cancel
Save