mirror of https://github.com/hpcaitech/ColossalAI
pass outproduct mean
parent 979e61db92
commit 929445116a
chunk_codegen.py
@@ -16,16 +16,31 @@ def _delete_free_var_from_last_use(user_to_last_uses):
             if n.op == 'placeholder':
                 user_to_last_uses[key].remove(n)


+def _get_node_shape(node):
+    if hasattr(node.meta['tensor_meta'], "shape"):
+        return node.meta['tensor_meta'].shape
+    return None
+
+
+def _is_non_compute_node(node):
+    if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
+            any(i in node.name for i in ['getitem', 'getattr']):
+        return True
+    return False
+
+
+def _is_non_compute_node_except_placeholder(node):
+    if any(i in node.op for i in ['get_attr', 'output']) or \
+            any(i in node.name for i in ['getitem', 'getattr']):
+        return True
+    return False
+
+
 class FlowTracer(object):

     def __init__(self, gm) -> None:
         self.gm = gm
-        self.nodes_list = list(gm.graph.nodes)
+        self.node_list = list(gm.graph.nodes)
         self.flow_trace = {}

     def _add_trace(self, name):
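
The three module-level helpers added above are pure predicates over torch.fx nodes. A minimal sketch of what they classify, not part of the commit (assumes only a stock torch install; Tiny is a made-up module):

    # Sketch: apply the same tests as _is_non_compute_node to a traced module.
    import torch
    from torch.fx import symbolic_trace

    class Tiny(torch.nn.Module):    # hypothetical example module
        def forward(self, x):
            return torch.matmul(x, x) + 1

    gm = symbolic_trace(Tiny())
    for n in gm.graph.nodes:
        non_compute = n.op in ('placeholder', 'get_attr', 'output') \
            or any(s in n.name for s in ('getitem', 'getattr'))
        print(n.op, n.name, 'non-compute' if non_compute else 'compute')

Only placeholder/get_attr/output ops and getitem/getattr calls count as non-compute; everything else is real work the tracers must follow.
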
@@ -49,7 +64,7 @@ class FlowTracer(object):
         raise RuntimeError("node not found")

     def _init_trace(self):
-        for i in self.nodes_list:
+        for i in self.node_list:
             if i.op == 'placeholder':
                 self._add_trace(i.name)
                 self._add_node(i.name, i)
@@ -67,7 +82,7 @@ class FlowTracer(object):
         return False

     def _find_flow_for_node(self, node):
-        if type(self.nodes_list[0]) != type(node):
+        if type(self.node_list[0]) != type(node):
             return None
         if self._is_non_compute_node_except_placeholder(node):
             return None
@@ -117,7 +132,7 @@ class FlowTracer(object):
         # init trace
         self._init_trace()

-        for node in self.nodes_list:
+        for node in self.node_list:
             # skip if non compute node
             if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \
                     or self._is_non_compute_node(node):
@@ -136,6 +151,41 @@ class FlowTracer(object):
                     self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
         return self.flow_trace

+    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
+        inputs, outputs = _find_chunk_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
+        chunk_info = {'region': (start_idx, end_idx),
+                      'inputs': inputs, 'inputs_dim': start_dim,
+                      'outputs': outputs, 'outputs_dim': end_dim,
+                      'args': {}}
+        flow_flag = False
+
+        for idx in range(start_idx, end_idx + 1):
+            node = self.node_list[idx]
+            mix_flow_var = self.get_flow_mix(node)
+            if mix_flow_var is None:
+                continue
+
+            # if there is a flow mix, op must be in [mul, add, div, matmul]
+            # element-wise ops require every dim to match
+            if any(n in node.name for n in ['mul', 'add']):
+                for i in node.args:
+                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
+                        main_flow_var = i
+                # if the mix flow is a broadcast along the chunk dim,
+                # TODO need to move that flow out of the chunk
+                if mix_flow_var.meta['tensor_meta'].shape[end_dim] == 1:
+                    flow_flag = True
+                    for i in self.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
+                        chunk_info['inputs'].remove(i)
+                # else, we need to chunk the mix var as well
+                else:
+                    # TODO chunk another value
+                    flow_flag = False
+                    break
+            else:
+                raise NotImplementedError("%s not implemented" % node.name)
+        return flow_flag, chunk_info
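
The broadcast test in the new FlowTracer._detect_flow leans on ordinary broadcasting semantics: an operand of size 1 along the chunk dim feeds every chunk identically, so it can stay outside the chunked region. A sketch of that invariant, not from the commit (plain torch, made-up shapes):

    # Sketch: a size-1 operand along the chunk dim gives the same result
    # whole-tensor or per chunk, so the mix-flow var may live outside the chunk.
    import torch

    x = torch.randn(4, 8)       # tensor chunked along dim 0
    bias = torch.randn(1, 8)    # broadcast along the chunk dim

    full = x + bias
    chunked = torch.cat([x[i:i + 2] + bias for i in range(0, 4, 2)], dim=0)
    assert torch.equal(full, chunked)
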


 class IndexTracer(object):

     def __init__(self, gm) -> None:
@@ -153,7 +203,7 @@ class IndexTracer(object):
             cur_trace = {
                 'idx': [None for _ in range(len(_get_node_shape(n)))],
                 'compute': [[] for _ in range(len(_get_node_shape(n)))],
-                'source': [[] for _ in range(len(_get_node_shape(n)))],
+                'source': [{} for _ in range(len(_get_node_shape(n)))],
             }
         else:
             cur_trace = {'idx': [], 'compute': [], 'source': []}
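
For orientation, one idx_trace_list entry now looks like the sketch below (values made up, not from the commit); the change in this hunk is that each dim's 'source' becomes a dict keyed by producer node index instead of a list:

    # Sketch of a trace entry for a 2-D node.
    cur_trace = {
        'idx': [0, 1],               # logical index id per dim
        'compute': [[], [5]],        # node positions where each dim was computed
        'source': [{3: 0}, {3: 1}],  # per dim: producer node index -> producer dim
    }
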
@@ -178,7 +228,7 @@ class IndexTracer(object):
     def _add_dim(self, idx, dim_idx):
         self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index())
         self.idx_trace_list[idx]['compute'].insert(dim_idx, [])
-        self.idx_trace_list[idx]['source'].insert(dim_idx, [])
+        self.idx_trace_list[idx]['source'].insert(dim_idx, {})

     def _transform_index(self, node, node_dim):
         node_idx = self._find_idx_trace_from_node(node)
@@ -192,10 +242,7 @@ class IndexTracer(object):
         node_to_trace = self._find_trace_from_node(node_to)
         node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim]
         node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim])
-        node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
-        node_to_trace['source'][node_to_dim] = []
-        node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
-        node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
+        self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)

     def _inherit_all_computation(self, node_from, node_to):
         node_from_compute = self._find_compute_trace_from_node(node_from)
@@ -205,14 +252,16 @@ class IndexTracer(object):
             self._add_source(node_from, i, node_to, i)
             node_to_compute[i] = copy.deepcopy(node_from_compute[i])

-    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim):
+    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
         node_from_dim = self._transform_index(node_from, node_from_dim)
         node_from_trace = self._find_trace_from_node(node_from)
         node_to_dim = self._transform_index(node_to, node_to_dim)
         node_to_trace = self._find_trace_from_node(node_to)
         node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
-        node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
-        node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
+        if init:
+            node_to_trace['source'][node_to_dim] = {}
+        node_to_trace['source'][node_to_dim][node_from_idx] = node_from_dim
+        node_to_trace['source'][node_to_dim].update(node_from_trace['source'][node_from_dim])

     def _mark_computation_from_node(self, node_from, node_to, exclude=None):
         if exclude == None:
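
With dicts, source inheritance becomes a plain merge, and init=True lets _inherit_index reset the map before merging. A toy sketch of that behaviour, not from the commit:

    # Sketch: _add_source's init/update semantics on toy source maps.
    from_trace_source = {2: 0, 4: 1}   # upstream: node 2 dim 0, node 4 dim 1
    to_trace_source = {9: 3}           # stale entry, discarded when init=True

    init = True
    if init:
        to_trace_source = {}
    to_trace_source[6] = 0             # direct producer: node 6, dim 0
    to_trace_source.update(from_trace_source)
    print(to_trace_source)             # {6: 0, 2: 0, 4: 1}
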
@@ -485,11 +534,11 @@ class IndexTracer(object):
                 source_idx = left_str.index(right_indice)
                 self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx)

-        for i in sum_index:
-            for left_idx, left_str in enumerate(left):
-                if i in left_str:
-                    self._mark_computation(node, idx, left_str.index(i))
-                    break
+        # for i in sum_index:
+        #     for left_idx, left_str in enumerate(left):
+        #         if i in left_str:
+        #             self._mark_computation(node, idx, left_str.index(i))
+        #             break

     def _assign_softmax_index(self, node, idx):
         """
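
For context on the block disabled above: in an einsum, an index that appears on the left but not on the right is summed out, which is what _mark_computation used to record for that dim. A plain torch sketch, not from the commit:

    # Sketch: in 'bik,bkj->bij', k is a sum index; the b/i/j dims pass through.
    import torch

    a = torch.randn(2, 3, 4)    # b i k
    b = torch.randn(2, 4, 5)    # b k j
    out = torch.einsum('bik,bkj->bij', a, b)
    print(out.shape)            # torch.Size([2, 3, 5]): k summed away
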
@@ -679,18 +728,56 @@ class IndexTracer(object):
             raise NotImplementedError(node.op, "op not implemented yet!")
         # self._merge_equal_idx()

-    def check_index(self, trace_idx, start_idx, end_idx):
-        for i in range(start_idx, end_idx + 1):
-            cur_idx = self.idx_trace_list[i]['idx']
-            cur_compute = self.idx_trace_list[i]['compute']
-            if trace_idx in cur_compute:
-                for j in cur_compute[trace_idx]:
-                    if j < start_idx or j > end_idx:
-                        return False
-        # same_idx = [1 if j == trace_idx else 0 for j in cur_idx]
-        # if sum(same_idx) > 1:
-        #     return False
+    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
+        """
+        Check two given indices: one should be the source of the other.
+        Args:
+            start_dim(int): start node chunk dim
+            start_node(node): start node
+            start_idx(int): start node position in the node list
+            end_dim(int): end node chunk dim
+            end_node(node): end node
+
+        Returns:
+            bool: True if the check passes
+        """
+        start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
+        end_node_trace = self._find_trace_from_node(end_node)
+        end_node_trace_source = end_node_trace['source'][end_dim]
+        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
+        for node_idx, node_dim in sorted_source:
+            if node_idx == start_node_idx and node_dim == start_dim:
+                return True
+            # we have reached a node outside the region that is not the input node
+            if node_idx < start_idx:
+                return False
+        return False
+
+    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
+        """
+        Check two given indices: the end dim must not have been computed inside the region.
+        Args:
+            start_idx(int): start node position in the node list
+            end_dim(int): end node chunk dim
+            end_node(node): end node
+            end_idx(int): end node position in the node list
+
+        Returns:
+            bool: True if the check passes
+        """
+        end_node_trace = self._find_trace_from_node(end_node)
+        end_node_compute = end_node_trace['compute'][end_dim]
+        if any(start_idx <= i <= end_idx for i in end_node_compute):
+            return False
+        return True
+        # end_node_trace_source = end_node_trace['source'][end_dim]
+        # for node_idx, node_dim in end_node_trace_source.items():
+        #     if node_idx < start_node_idx or node_idx > end_node_idx:
+        #         continue
+        #     compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
+        #     if any(start_node_idx <= i <= end_node_idx for i in compute_list):
+        #         return False
+        # return True
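
A toy walk through both new checks with a hand-built trace (all values made up, not from the commit): check_index_source scans the end dim's source map from the highest producer index down, and check_index_compute rejects a dim that was computed inside the candidate region:

    # Sketch: region checks on a hand-built trace.
    start_node_idx, start_dim = 3, 0
    start_idx, end_idx = 3, 7

    end_source = {6: 1, 3: 0}     # end dim inherited from node 6 dim 1, node 3 dim 0
    ok_source = False
    for node_idx, node_dim in sorted(end_source.items(), reverse=True):
        if node_idx == start_node_idx and node_dim == start_dim:
            ok_source = True      # found the start dim in the source chain
            break
        if node_idx < start_idx:  # producer before the region: fail
            break
    print(ok_source)              # True

    end_compute = [5]             # dim computed at node 5, inside [3, 7]
    print(not any(start_idx <= i <= end_idx for i in end_compute))  # False
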


 class MemoryEstimator(object):

     def __init__(self) -> None:
@@ -951,88 +1038,81 @@ class ChunkRegionSearch(object):
                 return True
         return False

-    def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
-        inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
-        chunk_info = {'inputs': inputs, 'outputs': outputs}
-        flow_flag = False
-
-        for idx in range(start_idx, end_idx + 1):
-            node = self.node_list[idx]
-            mix_flow_var = self.flow_tracer.get_flow_mix(node)
-            if mix_flow_var is None:
-                continue
-
-            # if there is a flow mix, op must be in [mul, add, div, matmul]
-            # element-wise op requires dim to be equal in every dim
-            if any(n in node.name for n in ['mul', 'add']):
-                for i in node.args:
-                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
-                        main_flow_var = i
-                # if mix flow is a broadcast in chunk dim,
-                # TODO need to move that flow out of the chunk
-                if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
-                    flow_flag = True
-                    for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
-                        chunk_info['inputs'].remove(i)
-                # else, we need to chunk mix var as well
-                else:
-                    # TODO chunk another value
-                    flow_flag = False
-                    break
-            else:
-                raise NotImplementedError("%s not implemented" % node.name)
-        return flow_flag, chunk_info
+    def _check_duplicate_map(self, chunk_infos):
+        dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos]
+        remove_list = []
+        for idx1, (input_dim1, output_dim1) in enumerate(dim_map):
+            for idx2, (input_dim2, output_dim2) in enumerate(dim_map):
+                if idx1 == idx2:
+                    continue
+                # an index creates 2 copies of itself,
+                # e.g. a = torch.matmul(x, x.transpose(-1, -2))
+                # TODO currently remove both, handle this case in the future
+                if input_dim1 == input_dim2 and output_dim1 != output_dim2:
+                    remove_list.append(chunk_infos[idx1])
+                    remove_list.append(chunk_infos[idx2])
+        for i in remove_list:
+            if i in chunk_infos:
+                chunk_infos.remove(i)
+        return chunk_infos
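
A toy run of the new _check_duplicate_map, not from the commit: two candidates that share an input dim but disagree on the output dim (the matmul(x, x.transpose(-1, -2)) case named in the comment) are both dropped:

    # Sketch: duplicate dim-map pruning on made-up chunk_infos.
    chunk_infos = [
        {'inputs_dim': 0, 'outputs_dim': 0},
        {'inputs_dim': 0, 'outputs_dim': 1},   # same input dim, new output dim
        {'inputs_dim': 1, 'outputs_dim': 1},
    ]
    dim_map = [(c['inputs_dim'], c['outputs_dim']) for c in chunk_infos]
    remove = [c for i1, (in1, out1) in enumerate(dim_map)
              for i2, (in2, out2) in enumerate(dim_map)
              if i1 != i2 and in1 == in2 and out1 != out2
              for c in (chunk_infos[i1], chunk_infos[i2])]
    chunk_infos = [c for c in chunk_infos if c not in remove]
    print(chunk_infos)   # only {'inputs_dim': 1, 'outputs_dim': 1} survives
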
     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
-        before_trace = input_trace[start_idx]
-        after_trace = output_trace[end_idx]
-        free_dim = []
-        for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
-            if not (before_trace['idx'][i] == after_trace['idx'][i] and
-                    self._is_not_compute(before_trace, (start_idx, end_idx), i) and
-                    self._is_not_compute(after_trace, (start_idx, end_idx), i) and
-                    self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
-                continue
-            if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx):
-                continue
-            flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
-            if flow_flag == None:
-                continue
-            free_dim.append(i)
-        return free_dim, chunk_infos
+        start_traces = input_trace[start_idx]
+        end_trace = output_trace[end_idx]
+        end_node = self.node_list[end_idx]
+        chunk_infos = []
+        for end_dim, end_trace_idx in enumerate(end_trace['idx']):
+            if len(start_traces) > 1:
+                # TODO implement multi input chunk
+                continue
+            for start_node, start_trace in start_traces.items():
+                for start_dim, start_trace_idx in enumerate(start_trace['idx']):
+                    # must be the same trace idx
+                    if start_trace_idx != end_trace_idx:
+                        continue
+                    # dim size cannot be 1
+                    if _get_node_shape(end_node)[end_dim] == 1 or \
+                            _get_node_shape(start_node)[start_dim] == 1:
+                        continue
+                    # check that index sources align
+                    if not self.index_tracer.check_index_source(
+                            start_dim, start_node, start_idx, end_dim, end_node):
+                        continue
+                    # check index compute
+                    if not self.index_tracer.check_index_compute(
+                            start_idx, end_dim, end_node, end_idx):
+                        continue
+                    # detect flow mix
+                    flow_flag, chunk_info = self.flow_tracer._detect_flow(
+                        start_idx, start_dim, end_idx, end_dim)
+                    if flow_flag:
+                        continue
+                    chunk_infos.append(chunk_info)
+        chunk_infos = self._check_duplicate_map(chunk_infos)
+        return chunk_infos
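
The net effect of the _find_free_dim rewrite: a chunkable dim is now a (start_dim, end_dim) pair whose trace ids match, instead of a single position shared by entry and exit. A sketch with made-up trace ids, not from the commit:

    # Sketch: pair entry/exit dims by matching trace ids.
    start_trace_idx = [10, 11]      # trace id per dim at the region entry
    end_trace_idx = [11, 12, 10]    # trace id per dim at the region exit

    pairs = [(s, e)
             for s, s_id in enumerate(start_trace_idx)
             for e, e_id in enumerate(end_trace_idx)
             if s_id == e_id]
    print(pairs)   # [(0, 2), (1, 0)]
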
     def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
-        input_trace = []
-        for i, n in enumerate(self.node_list):
-            if len(n.args) > 0 and n.op != 'output':
-                if isinstance(n.args[0], str):
-                    input_idx = _find_idx_by_name(n.args[1].name, self.node_list)
-                else:
-                    input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
-                input_trace.append(output_trace[input_idx])
-            else:
-                input_trace.append(None)
+        input_trace = []    # trace of a node's input nodes
+        for _, n in enumerate(self.node_list):
+            cur_trace = {}
+            for arg in n.args:
+                if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg):
+                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
+            input_trace.append(cur_trace)

-        for start_idx in range(max_chunk_region[0], peak_node):
+        for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if any(op in ['placeholder', 'get_attr', 'output'] for op in
-                       [self.node_list[start_idx].op, self.node_list[end_idx].op]):
-                    continue
-                if any(any(i in name for i in ['getitem', 'getattr']) for name in
-                       [self.node_list[start_idx].name, self.node_list[end_idx].name]):
+                if _is_non_compute_node(self.node_list[start_idx]) or \
+                        _is_non_compute_node(self.node_list[end_idx]):
                     continue

                 # select free dim
-                free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
-                if len(free_dim) > 0:
-                    free_dim = [free_dim[0]]
-                    chunk_info = [chunk_info[0]]
-                    possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
+                chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
+                if len(chunk_info) > 0:
+                    possible_chunk_region.extend(chunk_info)
         return possible_chunk_region

     def _search_best_chunk_region(self, possible_chunk_regions):
@@ -1044,7 +1124,8 @@ class ChunkRegionSearch(object):
                 max_region_range = i['region'][1] - i['region'][0]
         return best_regions

-    def _step_search(self, peak_node, active_node):
+    def _step_search(self, mem_peak, active_node):
+        peak_node = self._find_peak_node(mem_peak)
         max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
         possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
         best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
@@ -1062,19 +1143,16 @@ class ChunkRegionSearch(object):
         mem_peak = init_mem_peak

         while True:
-            peak_node = self._find_peak_node(mem_peak)
-            chunk_region = self._step_search(peak_node, active_node)
-            if chunk_region is None or len(chunk_region['dim']) == 0:
+            chunk_region = self._step_search(mem_peak, active_node)
+            if chunk_region is None:
                 break

             chunk_regions.append(chunk_region)
             mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(
                 self.gm, [i['region'][0] for i in chunk_regions],
-                [i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions))
+                [i['region'][1] for i in chunk_regions], [i['inputs_dim'] for i in chunk_regions], [1] * len(chunk_regions))
             if self._stop_search(init_mem_peak, mem_peak):
                 break

         return chunk_regions
@@ -1164,6 +1242,35 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes


+def _find_chunk_input_and_output_nodes(nodes: List[Node]):
+    """
+    Find non-compute input and output node names.
+    Input nodes are nodes outside the list that the list uses;
+    output nodes are nodes in the list whose results are used outside it.
+    """
+    input_nodes = []
+    output_nodes = []
+
+    # if a node has an input node which is not in the node list
+    # we treat that input node as an input of the checkpoint function
+    for node in nodes:
+        for input_node in node._input_nodes.keys():
+            if input_node not in nodes and input_node not in input_nodes \
+                    and not _is_non_compute_node_except_placeholder(input_node):
+                input_nodes.append(input_node)
+
+    # if a node has a user node which is not in the node list
+    # we treat that user node as receiving the current node's output
+    # TODO it is unsafe to remove non compute nodes here
+    for node in nodes:
+        for output_node in node.users.keys():
+            if output_node not in nodes and node not in output_nodes \
+                    and not _is_non_compute_node_except_placeholder(output_node):
+                output_nodes.append(node)
+
+    return input_nodes, output_nodes
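
A sketch of what the new helper computes on a traced graph, not from the commit (made-up three-op module; the slice just picks the middle node):

    # Sketch: boundary nodes of a node sub-list in an fx graph.
    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):        # hypothetical module
        def forward(self, x):
            a = x + 1
            b = a * 2
            return b - 3

    gm = symbolic_trace(M())
    nodes = list(gm.graph.nodes)
    region = nodes[2:3]              # just the mul node
    inputs = [p for n in region for p in n._input_nodes if p not in region]
    outputs = [n for n in region if any(u not in region for u in n.users)]
    print([n.name for n in inputs], [n.name for n in outputs])  # ['add'] ['mul']
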


 def _find_idx_by_name(name, nodes_list):
     for idx, node in enumerate(nodes_list):
         if node.name == name: