support index duplicate and update loop

pull/2364/head
oahzxl 2022-12-13 10:02:26 +08:00
parent 1e0fd11bc1
commit cda3e8572a
2 changed files with 76 additions and 37 deletions


@@ -180,7 +180,7 @@ class FlowTracer(object):
"args": {},
}
flow_block = False
# TODO: multiple outputs are not allowed for now
if len(outputs) > 1:
flow_block = True
@@ -200,7 +200,7 @@ class FlowTracer(object):
main_flow_var = i
# if the mix flow is a broadcast in the chunk dim,
# TODO: that flow needs to be moved out of the chunk
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
mix_flow_node_dim = index_tracer.get_node_chunk_dim(
self.node_list[end_idx], end_dim, node
)
if mix_flow_node_dim is None:
@@ -223,7 +223,7 @@ class FlowTracer(object):
if flow_block:
flow_block = True
return flow_block, chunk_info
inputs_dim = []
remove_inputs = []
for input_node in chunk_info["inputs"]:
@@ -234,7 +234,7 @@ class FlowTracer(object):
user_idx = _find_idx_by_name(user.name, self.node_list)
dim = None
if start_dim <= user_idx < end_idx:
dim = index_tracer._get_node_chunk_dim(
dim = index_tracer.get_node_chunk_dim(
self.node_list[end_idx], end_dim, input_node
)
elif user_idx == end_idx:
@@ -300,10 +300,10 @@ class IndexTracer(object):
self.idx_trace_list[idx]["compute"].pop(dim_idx)
self.idx_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, idx, dim_idx):
self.idx_trace_list[idx]["idx"].insert(dim_idx, self._add_index())
self.idx_trace_list[idx]["compute"].insert(dim_idx, [])
self.idx_trace_list[idx]["source"].insert(dim_idx, {})
def _add_dim(self, node_idx, dim_idx):
self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index())
self.idx_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.idx_trace_list[node_idx]["source"].insert(dim_idx, {})
def _transform_index(self, node, node_dim):
node_idx = self._find_idx_trace_from_node(node)
@@ -659,9 +659,7 @@ class IndexTracer(object):
"""
self._del_dim(node_idx, -1)
self._assign_index_as_input(node, node_idx)
self.idx_trace_list[node_idx]["idx"].insert(node.args[1], self._add_index())
self.idx_trace_list[node_idx]["compute"].insert(node.args[1], [])
self.idx_trace_list[node_idx]["source"].insert(node.args[1], [])
self._add_dim(node_idx, node.args[1])
def _assign_dropout_index(self, node, node_idx):
"""
@@ -879,7 +877,7 @@ class IndexTracer(object):
return False
return True
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim]
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
@@ -888,6 +886,44 @@ class IndexTracer(object):
return v
return None
def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_trace_source = self._find_source_trace_from_node(node)
for node_dim in range(len(_get_node_shape(node))):
if (
input_node_idx in node_trace_source[node_dim]
and node_trace_source[node_dim][input_node_idx] == input_dim
):
return {node_idx: node_dim}
return {}
def check_index_duplicate(self, chunk_infos):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
input_dim_after_node.update(
self._find_inherit_dim(input_node, v, self.nodes_list[k])
)
for node in self.nodes_list[
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
]:
if _is_non_compute_node_except_placeholder(node):
continue
count = 0
node_trace_source = self._find_source_trace_from_node(node)
for node_dim in range(len(_get_node_shape(node))):
dim_source = node_trace_source[node_dim]
for k, v in dim_source.items():
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
if k in input_dim_after_node and input_dim_after_node[k] == v:
count += 1
break
if count > 1:
return False
return True
class MemoryEstimator(object):
def __init__(self) -> None:
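Reading the new `check_index_duplicate` above: it first records, via `_find_inherit_dim`, which dimension of each in-region user inherits a chunked input dimension, then counts, for every compute node in the region, how many of that node's dimensions trace back to such a dimension; more than one hit on a single node means the chunked index is duplicated and the candidate region must be rejected. A simplified, self-contained model of that counting rule, assuming a hypothetical plain-dict layout mirroring what `_find_source_trace_from_node` returns:

def check_index_duplicate(node_sources, region, chunked_dim_of):
    # node_sources: per node, a list over dims of {source_node_idx: source_dim}
    # chunked_dim_of: node idx -> the dim that inherits the chunked input dim
    for dims in node_sources:
        count = 0
        for dim_source in dims:
            for src_idx, src_dim in dim_source.items():
                if region[0] <= src_idx <= region[1] and chunked_dim_of.get(src_idx) == src_dim:
                    count += 1
                    break
        if count > 1:
            return False  # one chunked index feeds two dims of the same node
    return True

# a = matmul(x, x.transpose(-1, -2)): both dims of `a` inherit x's dim 0
print(check_index_duplicate([[{0: 0}, {0: 0}]], (0, 1), {0: 0}))  # False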
@@ -1160,7 +1196,7 @@ class ChunkRegionSearch(object):
min_len = len(n)
return min_len
def _search_max_chunk_region(self, active_node, peak_node):
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
free_vars = self._get_free_var()
min_var = self._get_min_free_var(active_node, free_vars)
@@ -1180,6 +1216,21 @@ class ChunkRegionSearch(object):
break
if i in free_vars or i == 0:
raise RuntimeError()
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (
region[0] <= chunk_region_start <= region[1]
and chunk_region_end > region[1]
):
chunk_region_start = region[1] + 1
elif (
region[0] <= chunk_region_end <= region[1]
and chunk_region_start < region[0]
):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
def _is_not_compute(self, trace, chunk_range, dim_idx):
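The block added above trims a fresh candidate region against regions already chosen in earlier search steps: a candidate fully contained in an existing region is abandoned (`None`), while a partial overlap shrinks the candidate from the overlapped side. The same rule as a standalone sketch, with regions taken as inclusive `[start, end]` pairs over the node list:

def trim_against_existing(start, end, chunk_regions):
    for info in chunk_regions:
        s, e = info["region"]
        if start >= s and end <= e:
            return None          # fully covered: nothing new to search here
        elif s <= start <= e and end > e:
            start = e + 1        # left side overlaps: move start past it
        elif s <= end <= e and start < s:
            end = s - 1          # right side overlaps: pull end before it
    return start, end

print(trim_against_existing(3, 10, [{"region": (0, 5)}]))  # (6, 10)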
@@ -1192,24 +1243,6 @@ class ChunkRegionSearch(object):
return True
return False
def _check_duplicate_map(self, chunk_infos):
dim_map = [(i["inputs_dim"], i["outputs_dim"]) for i in chunk_infos]
remove_list = []
for idx1, (input_dim1, output_dim1) in enumerate(dim_map):
for idx2, (input_dim2, output_dim2) in enumerate(dim_map):
if idx1 == idx2:
continue
# it means an index creates 2 copies of itself
# e.g. a = torch.matmul(x, x.transpose(-1, -2))
# TODO: currently we remove it; deal with this in the future
if input_dim1 == input_dim2 and output_dim1 != output_dim2:
remove_list.append(chunk_infos[idx1])
remove_list.append(chunk_infos[idx2])
for i in remove_list:
if i in chunk_infos:
chunk_infos.remove(i)
return chunk_infos
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
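The deleted pairwise post-filter `_check_duplicate_map` is superseded by the per-candidate `check_index_duplicate` call added in the next hunk. The pattern both guard against is the one the removed comment cites: a single tensor index reaching a node through two different dimensions, e.g.:

import torch

x = torch.randn(4, 8)
a = torch.matmul(x, x.transpose(-1, -2))  # shape (4, 4)
# chunking x along dim 0 would make BOTH dims of `a` depend on the
# chunked index, so the region cannot be emitted chunk by chunk
print(a.shape)  # torch.Size([4, 4])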
@@ -1246,8 +1279,10 @@ class ChunkRegionSearch(object):
)
if flow_block:
continue
# check index duplicate
if not self.index_tracer.check_index_duplicate(chunk_info):
continue
chunk_infos.append(chunk_info)
chunk_infos = self._check_duplicate_map(chunk_infos)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
@@ -1288,9 +1323,13 @@ class ChunkRegionSearch(object):
max_region_range = i["region"][1] - i["region"][0]
return best_regions
def _step_search(self, mem_peak, active_node):
def _step_search(self, mem_peak, active_node, chunk_regions):
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
max_chunk_region = self._search_max_chunk_region(
active_node, peak_node, chunk_regions
)
if max_chunk_region is None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
@@ -1313,7 +1352,7 @@ class ChunkRegionSearch(object):
mem_peak = init_mem_peak
while True:
chunk_region = self._step_search(mem_peak, active_node)
chunk_region = self._step_search(mem_peak, active_node, chunk_regions)
if chunk_region is None:
break


@@ -46,8 +46,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
non_fx_out = model(node, pair)
fx_out = gm(node, pair)
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output"
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output"
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1]))
# test backward
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
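For reference, the reworked asserts attach the mean absolute deviation to the failure message; a minimal illustration of the same pattern, with dummy tensors chosen so the check passes:

import torch

non_fx_out = torch.zeros(3)
fx_out = torch.full((3,), 2e-5)  # within the 1e-4 tolerance
assert torch.allclose(non_fx_out, fx_out, atol=1e-4), \
    "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out - fx_out))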