finishi codegen on msa

pull/2364/head
oahzxl 2022-12-08 15:16:10 +08:00
parent 6d99994a7a
commit 2b4ebcc278
1 changed files with 188 additions and 24 deletions

View File

@ -17,6 +17,121 @@ def _delete_free_var_from_last_use(user_to_last_uses):
user_to_last_uses[key].remove(n) user_to_last_uses[key].remove(n)
class FlowTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
self.nodes_list = list(gm.graph.nodes)
self.flow_trace = {}
def _add_trace(self, name):
self.flow_trace[name] = []
def _add_node(self, trace_name, node):
self.flow_trace[trace_name].append({'node': node, 'inside_depend': [], 'outside_depend': []})
def _add_inside_depend(self, flow_name, node, inside_depend_node):
for i in self.flow_trace[flow_name]:
if i['node'] == node:
i['inside_depend'].append(inside_depend_node)
return
raise RuntimeError("node not found")
def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depend_trace):
for i in self.flow_trace[flow_name]:
if i['node'] == node:
i['outside_depend'].append({outside_depend_trace: outside_depend_node})
return
raise RuntimeError("node not found")
def _init_trace(self):
for i in self.nodes_list:
if i.op == 'placeholder':
self._add_trace(i.name)
self._add_node(i.name, i)
def _is_non_compute_node(self, node):
if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
any(i in node.name for i in ['getitem', 'getattr']):
return True
return False
def _is_non_compute_node_except_placeholder(self, node):
if any(i in node.op for i in ['get_attr', 'output']) or \
any(i in node.name for i in ['getitem', 'getattr']):
return True
return False
def _find_flow_for_node(self, node):
if type(self.nodes_list[0]) != type(node):
return None
if self._is_non_compute_node_except_placeholder(node):
return None
for name, trace in self.flow_trace.items():
for i in trace:
if node == i['node']:
return name
if any(i in node.name for i in ["ones_like"]):
self._add_trace(node.name)
self._add_node(node.name, node)
return node.name
raise RuntimeError("node not found")
def _find_first_valid_flow(self, flow):
for i in flow:
if i is not None:
return i
raise RuntimeError("invalid flow")
def find_node_flow(self, node):
for name, trace in self.flow_trace.items():
for i in trace:
if node == i['node']:
return name, i
raise RuntimeError("invalid node")
def get_flow_mix(self, node):
if self._is_non_compute_node(node):
return None
_, node_trace = self.find_node_flow(node)
if len(node_trace['outside_depend']) == 0:
return None
elif len(node_trace['outside_depend']) > 1:
raise NotImplementedError
vars = list(node_trace['outside_depend'][0].values())[0]
return vars
def get_same_flow_node(self, node_list, node):
name, _ = self.find_node_flow(node)
result = []
for i in self.flow_trace[name]:
if i['node'] in node_list:
result.append(i['node'])
return result
def trace_flow(self):
# init trace
self._init_trace()
for node in self.nodes_list:
# skip if non compute node
if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \
or self._is_non_compute_node(node):
continue
node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
node_domin_flow = self._find_first_valid_flow(node_input_flows)
self._add_node(node_domin_flow, node)
for node_input_flow, arg in zip(node_input_flows, node.args):
if node_input_flow is None:
continue
elif node_input_flow == node_domin_flow:
self._add_inside_depend(node_domin_flow, node, arg)
else:
self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
return self.flow_trace
class IndexTracer(object): class IndexTracer(object):
def __init__(self, gm) -> None: def __init__(self, gm) -> None:
self.gm = gm self.gm = gm
@ -428,7 +543,7 @@ class IndexTracer(object):
if merge_from in trace['idx']: if merge_from in trace['idx']:
trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']] trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']]
def trace_node_idx(self): def trace_index(self):
for idx, node in enumerate(self.nodes_list): for idx, node in enumerate(self.nodes_list):
if node.op == 'placeholder': if node.op == 'placeholder':
self._assign_all_index(node, idx) self._assign_all_index(node, idx)
@ -684,7 +799,9 @@ class ChunkRegionSearch(object):
self.node_list = list(gm.graph.nodes) self.node_list = list(gm.graph.nodes)
self.memory_estimator = MemoryEstimator() self.memory_estimator = MemoryEstimator()
self.index_tracer = IndexTracer(gm) self.index_tracer = IndexTracer(gm)
self.index_tracer.trace_node_idx() self.index_tracer.trace_index()
self.flow_tracer = FlowTracer(gm)
self.flow_tracer.trace_flow()
def _find_peak_node(self, mem_peak): def _find_peak_node(self, mem_peak):
max_value = max(mem_peak) max_value = max(mem_peak)
@ -729,7 +846,7 @@ class ChunkRegionSearch(object):
raise RuntimeError() raise RuntimeError()
return chunk_region_start, chunk_region_end return chunk_region_start, chunk_region_end
def _not_compute(self, trace, chunk_range, dim_idx): def _is_not_compute(self, trace, chunk_range, dim_idx):
if trace['idx'][dim_idx] not in trace['compute']: if trace['idx'][dim_idx] not in trace['compute']:
return True return True
if trace['idx'][dim_idx] in trace['compute'] and \ if trace['idx'][dim_idx] in trace['compute'] and \
@ -737,6 +854,56 @@ class ChunkRegionSearch(object):
return True return True
return False return False
def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
chunk_info = {'inputs': inputs, 'outputs': outputs}
flow_flag = False
for idx in range(start_idx, end_idx + 1):
node = self.node_list[idx]
mix_flow_var = self.flow_tracer.get_flow_mix(node)
if mix_flow_var is None:
continue
# if there is a flow mix, op must be in [mul, add, div, matmul]
# element-wise op requires dim to be equal in every dim
if any(n in node.name for n in ['mul', 'add']):
for i in node.args:
if type(i) == type(mix_flow_var) and i != mix_flow_var:
main_flow_var = i
# if mix flow is a broadcast in chunk dim,
# TODO need to move that flow out of the chunk
if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
flow_flag = True
for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
chunk_info['inputs'].remove(i)
# else, we need to chunk mix var as well
else:
# TODO chunk another value
flow_flag = False
break
else:
raise NotImplementedError("%s not implemented" % node.name)
return flow_flag, chunk_info
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
before_trace = input_trace[start_idx]
after_trace = output_trace[end_idx]
free_dim = []
chunk_infos = []
for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
if not (before_trace['idx'][i] == after_trace['idx'][i] and
self._is_not_compute(before_trace, (start_idx, end_idx), i) and
self._is_not_compute(after_trace, (start_idx, end_idx), i) and
self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
continue
flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
if flow_flag == None:
continue
chunk_infos.append(chunk_info)
free_dim.append(i)
return free_dim, chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region, peak_node): def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
possible_chunk_region = [] possible_chunk_region = []
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
@ -748,27 +915,22 @@ class ChunkRegionSearch(object):
else: else:
input_trace.append(None) input_trace.append(None)
for before_idx in range(max_chunk_region[0], peak_node): for start_idx in range(max_chunk_region[0], peak_node):
for after_idx in range(peak_node, max_chunk_region[1] + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes # skip non compute nodes
if any(op in ['placeholder', 'get_attr', 'output'] for op in if any(op in ['placeholder', 'get_attr', 'output'] for op in
[self.node_list[before_idx].op, self.node_list[after_idx].op]): [self.node_list[start_idx].op, self.node_list[end_idx].op]):
continue continue
if any(any(i in name for i in ['getitem', 'getattr']) for name in if any(any(i in name for i in ['getitem', 'getattr']) for name in
[self.node_list[before_idx].name, self.node_list[after_idx].name]): [self.node_list[start_idx].name, self.node_list[end_idx].name]):
continue continue
# select free dim # select free dim
before_trace = input_trace[before_idx] free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
after_trace = output_trace[after_idx] if len(free_dim) > 0:
free_dim = [] free_dim = [free_dim[0]]
for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): chunk_info = [chunk_info[0]]
if (before_trace['idx'][i] == after_trace['idx'][i] and possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
self._not_compute(before_trace, (before_idx, after_idx), i) and
self._not_compute(after_trace, (before_idx, after_idx), i) and
self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1):
free_dim.append(i)
possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
return possible_chunk_region return possible_chunk_region
def _search_best_chunk_region(self, possible_chunk_regions): def _search_best_chunk_region(self, possible_chunk_regions):
@ -935,21 +1097,23 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
chunk_search = chunk_region_search.search_region() chunk_search = chunk_region_search.search_region()
chunk_regions = [i['region'] for i in chunk_search] chunk_regions = [i['region'] for i in chunk_search]
chunk_dims = [i['dim'] for i in chunk_search] chunk_dims = [i['dim'] for i in chunk_search]
chunk_infos = [i['chunk_info'] for i in chunk_search]
chunk_starts = [item[0] for item in chunk_regions] chunk_starts = [item[0] for item in chunk_regions]
chunk_ends = [item[1] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions]
chunk_inputs = [] chunk_inputs = [[j['inputs'][0] for j in i] for i in chunk_infos]
chunk_outputs = [] chunk_outputs = [[j['outputs'][0] for j in i] for i in chunk_infos]
within_chunk_region = False within_chunk_region = False
node_list = list(nodes) node_list = list(nodes)
# find the input and output var names for each offload region # find the input and output var names for each offload region
for idx, (start, end) in enumerate(chunk_regions): # for idx, (start, end) in enumerate(chunk_regions):
offload_node_list = node_list[start:end + 1] # offload_node_list = node_list[start:end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list) # inputs, outputs = _find_input_and_output_nodes(offload_node_list)
chunk_inputs.append(inputs) # chunk_inputs.append(inputs)
chunk_outputs.append(outputs) # chunk_outputs.append(outputs)
chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs] chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs] chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
chunk_inputs_names = [] chunk_inputs_names = []