mirror of https://github.com/hpcaitech/ColossalAI
finish codegen on msa
parent 6d99994a7a
commit 2b4ebcc278
chunk_codegen.py (212 changed lines)
@@ -17,6 +17,121 @@ def _delete_free_var_from_last_use(user_to_last_uses):
                 user_to_last_uses[key].remove(n)
 
 
+class FlowTracer(object):
+
+    def __init__(self, gm) -> None:
+        self.gm = gm
+        self.nodes_list = list(gm.graph.nodes)
+        self.flow_trace = {}
+
+    def _add_trace(self, name):
+        self.flow_trace[name] = []
+
+    def _add_node(self, trace_name, node):
+        self.flow_trace[trace_name].append({'node': node, 'inside_depend': [], 'outside_depend': []})
+
+    def _add_inside_depend(self, flow_name, node, inside_depend_node):
+        for i in self.flow_trace[flow_name]:
+            if i['node'] == node:
+                i['inside_depend'].append(inside_depend_node)
+                return
+        raise RuntimeError("node not found")
+
+    def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depend_trace):
+        for i in self.flow_trace[flow_name]:
+            if i['node'] == node:
+                i['outside_depend'].append({outside_depend_trace: outside_depend_node})
+                return
+        raise RuntimeError("node not found")
+
+    def _init_trace(self):
+        for i in self.nodes_list:
+            if i.op == 'placeholder':
+                self._add_trace(i.name)
+                self._add_node(i.name, i)
+
+    def _is_non_compute_node(self, node):
+        if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
+                any(i in node.name for i in ['getitem', 'getattr']):
+            return True
+        return False
+
+    def _is_non_compute_node_except_placeholder(self, node):
+        if any(i in node.op for i in ['get_attr', 'output']) or \
+                any(i in node.name for i in ['getitem', 'getattr']):
+            return True
+        return False
+
+    def _find_flow_for_node(self, node):
+        if type(self.nodes_list[0]) != type(node):
+            return None
+        if self._is_non_compute_node_except_placeholder(node):
+            return None
+        for name, trace in self.flow_trace.items():
+            for i in trace:
+                if node == i['node']:
+                    return name
+        if any(i in node.name for i in ["ones_like"]):
+            self._add_trace(node.name)
+            self._add_node(node.name, node)
+            return node.name
+        raise RuntimeError("node not found")
+
+    def _find_first_valid_flow(self, flow):
+        for i in flow:
+            if i is not None:
+                return i
+        raise RuntimeError("invalid flow")
+
+    def find_node_flow(self, node):
+        for name, trace in self.flow_trace.items():
+            for i in trace:
+                if node == i['node']:
+                    return name, i
+        raise RuntimeError("invalid node")
+
+    def get_flow_mix(self, node):
+        if self._is_non_compute_node(node):
+            return None
+        _, node_trace = self.find_node_flow(node)
+        if len(node_trace['outside_depend']) == 0:
+            return None
+        elif len(node_trace['outside_depend']) > 1:
+            raise NotImplementedError
+        vars = list(node_trace['outside_depend'][0].values())[0]
+        return vars
+
+    def get_same_flow_node(self, node_list, node):
+        name, _ = self.find_node_flow(node)
+        result = []
+        for i in self.flow_trace[name]:
+            if i['node'] in node_list:
+                result.append(i['node'])
+        return result
+
+    def trace_flow(self):
+        # init trace
+        self._init_trace()
+
+        for node in self.nodes_list:
+            # skip if non compute node
+            if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \
+                    or self._is_non_compute_node(node):
+                continue
+
+            node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
+
+            node_domin_flow = self._find_first_valid_flow(node_input_flows)
+            self._add_node(node_domin_flow, node)
+            for node_input_flow, arg in zip(node_input_flows, node.args):
+                if node_input_flow is None:
+                    continue
+                elif node_input_flow == node_domin_flow:
+                    self._add_inside_depend(node_domin_flow, node, arg)
+                else:
+                    self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
+        return self.flow_trace
+
+
 class IndexTracer(object):
 
     def __init__(self, gm) -> None:
         self.gm = gm
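For orientation (not part of the commit), here is a minimal sketch of how the FlowTracer added above might be driven. It assumes the class is in scope and torch is installed; the module and tensor names are illustrative only.

import torch
import torch.fx


class TwoFlows(torch.nn.Module):
    # 'x' and 'bias' are placeholders, so each opens its own flow
    def forward(self, x, bias):
        y = x * 2           # stays inside x's flow
        return y + bias     # mixes x's flow with bias's flow


gm = torch.fx.symbolic_trace(TwoFlows())
flow_trace = FlowTracer(gm).trace_flow()

# Each placeholder name maps to the nodes assigned to its flow; the
# 'add' node in x's flow should list bias as an outside dependency.
for name, trace in flow_trace.items():
    for entry in trace:
        print(name, entry['node'].name,
              [n.name for d in entry['outside_depend'] for n in d.values()])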
@@ -428,7 +543,7 @@ class IndexTracer(object):
         if merge_from in trace['idx']:
             trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']]
 
-    def trace_node_idx(self):
+    def trace_index(self):
         for idx, node in enumerate(self.nodes_list):
             if node.op == 'placeholder':
                 self._assign_all_index(node, idx)
@@ -684,7 +799,9 @@ class ChunkRegionSearch(object):
         self.node_list = list(gm.graph.nodes)
         self.memory_estimator = MemoryEstimator()
         self.index_tracer = IndexTracer(gm)
-        self.index_tracer.trace_node_idx()
+        self.index_tracer.trace_index()
+        self.flow_tracer = FlowTracer(gm)
+        self.flow_tracer.trace_flow()
 
     def _find_peak_node(self, mem_peak):
         max_value = max(mem_peak)
@@ -729,7 +846,7 @@ class ChunkRegionSearch(object):
             raise RuntimeError()
         return chunk_region_start, chunk_region_end
 
-    def _not_compute(self, trace, chunk_range, dim_idx):
+    def _is_not_compute(self, trace, chunk_range, dim_idx):
         if trace['idx'][dim_idx] not in trace['compute']:
             return True
         if trace['idx'][dim_idx] in trace['compute'] and \
@@ -737,6 +854,56 @@ class ChunkRegionSearch(object):
             return True
         return False
 
+    def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
+        inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
+        chunk_info = {'inputs': inputs, 'outputs': outputs}
+        flow_flag = False
+
+        for idx in range(start_idx, end_idx + 1):
+            node = self.node_list[idx]
+            mix_flow_var = self.flow_tracer.get_flow_mix(node)
+            if mix_flow_var is None:
+                continue
+
+            # if there is a flow mix, op must be in [mul, add, div, matmul]
+            # element-wise op requires dim to be equal in every dim
+            if any(n in node.name for n in ['mul', 'add']):
+                for i in node.args:
+                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
+                        main_flow_var = i
+                # if mix flow is a broadcast in chunk dim,
+                # TODO need to move that flow out of the chunk
+                if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
+                    flow_flag = True
+                    for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
+                        chunk_info['inputs'].remove(i)
+                # else, we need to chunk mix var as well
+                else:
+                    # TODO chunk another value
+                    flow_flag = False
+                    break
+            else:
+                raise NotImplementedError("%s not implemented" % node.name)
+        return flow_flag, chunk_info
+
+    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
+        before_trace = input_trace[start_idx]
+        after_trace = output_trace[end_idx]
+        free_dim = []
+        chunk_infos = []
+        for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
+            if not (before_trace['idx'][i] == after_trace['idx'][i] and
+                    self._is_not_compute(before_trace, (start_idx, end_idx), i) and
+                    self._is_not_compute(after_trace, (start_idx, end_idx), i) and
+                    self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
+                continue
+            flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
+            if flow_flag is None:
+                continue
+            chunk_infos.append(chunk_info)
+            free_dim.append(i)
+        return free_dim, chunk_infos
+
     def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
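For intuition (an illustrative sketch, not part of the commit): the broadcast case that _detect_flow accepts is the one where the mixed-in tensor has size 1 along the chunk dim, so it broadcasts identically into every chunk and can stay outside the chunked region. With plain tensors:

import torch

x = torch.randn(4, 8, 16)       # main flow, to be chunked along dim 0
bias = torch.randn(1, 8, 16)    # mixed-in flow, size 1 in the chunk dim

full = x + bias                                        # unchunked reference
chunks = [x[i:i + 2] + bias for i in range(0, 4, 2)]   # bias reused as-is per chunk
assert torch.equal(torch.cat(chunks, dim=0), full)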
@@ -748,27 +915,22 @@ class ChunkRegionSearch(object):
             else:
                 input_trace.append(None)
 
-        for before_idx in range(max_chunk_region[0], peak_node):
-            for after_idx in range(peak_node, max_chunk_region[1] + 1):
+        for start_idx in range(max_chunk_region[0], peak_node):
+            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
                 if any(op in ['placeholder', 'get_attr', 'output'] for op in
-                       [self.node_list[before_idx].op, self.node_list[after_idx].op]):
+                       [self.node_list[start_idx].op, self.node_list[end_idx].op]):
                     continue
                 if any(any(i in name for i in ['getitem', 'getattr']) for name in
-                       [self.node_list[before_idx].name, self.node_list[after_idx].name]):
+                       [self.node_list[start_idx].name, self.node_list[end_idx].name]):
                     continue
 
                 # select free dim
-                before_trace = input_trace[before_idx]
-                after_trace = output_trace[after_idx]
-                free_dim = []
-                for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
-                    if (before_trace['idx'][i] == after_trace['idx'][i] and
-                            self._not_compute(before_trace, (before_idx, after_idx), i) and
-                            self._not_compute(after_trace, (before_idx, after_idx), i) and
-                            self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1):
-                        free_dim.append(i)
-                possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
+                free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
+                if len(free_dim) > 0:
+                    free_dim = [free_dim[0]]
+                    chunk_info = [chunk_info[0]]
+                    possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
         return possible_chunk_region
 
     def _search_best_chunk_region(self, possible_chunk_regions):
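As a side note (illustration only, not part of the commit), the renamed start_idx/end_idx loops enumerate every candidate region that straddles the peak-memory node. A standalone sketch of that enumeration:

def candidate_regions(lo, hi, peak):
    # all (start, end) spans with start < peak <= end, mirroring the
    # loops in _search_possible_chunk_regions before free-dim filtering
    return [(s, e) for s in range(lo, peak) for e in range(peak, hi + 1)]

print(candidate_regions(0, 4, 2))
# [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4)]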
@@ -935,21 +1097,23 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     chunk_search = chunk_region_search.search_region()
     chunk_regions = [i['region'] for i in chunk_search]
     chunk_dims = [i['dim'] for i in chunk_search]
+    chunk_infos = [i['chunk_info'] for i in chunk_search]
 
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
-    chunk_inputs = []
-    chunk_outputs = []
+    chunk_inputs = [[j['inputs'][0] for j in i] for i in chunk_infos]
+    chunk_outputs = [[j['outputs'][0] for j in i] for i in chunk_infos]
     within_chunk_region = False
 
     node_list = list(nodes)
 
     # find the input and output var names for each offload region
-    for idx, (start, end) in enumerate(chunk_regions):
-        offload_node_list = node_list[start:end + 1]
-        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
-        chunk_inputs.append(inputs)
-        chunk_outputs.append(outputs)
+    # for idx, (start, end) in enumerate(chunk_regions):
+    #     offload_node_list = node_list[start:end + 1]
+    #     inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+    #     chunk_inputs.append(inputs)
+    #     chunk_outputs.append(outputs)
 
     chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
     chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
     chunk_inputs_names = []
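Finally, a hypothetical sketch (not part of the commit) of how the pieces appear to fit together after this change. The import path and the constructor call are assumptions based on the hunks above; gm is an already-traced torch.fx.GraphModule.

import torch.fx

# from chunk_codegen import ChunkRegionSearch   # assumed import path


def plan_chunks(gm: torch.fx.GraphModule):
    # The ChunkRegionSearch hunk above shows __init__ now running
    # trace_index() and trace_flow(), so construction does the tracing.
    search = ChunkRegionSearch(gm)
    # Each result carries the node span, the chunkable dim, and the
    # region's input/output nodes under 'chunk_info'.
    return [(r['region'], r['dim'], r['chunk_info']) for r in search.search_region()]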