mirror of https://github.com/hpcaitech/ColossalAI
support index dupilictae and update loop
parent
1e0fd11bc1
commit
cda3e8572a
109
chunk_codegen.py
109
chunk_codegen.py
|
@ -180,7 +180,7 @@ class FlowTracer(object):
|
||||||
"args": {},
|
"args": {},
|
||||||
}
|
}
|
||||||
flow_block = False
|
flow_block = False
|
||||||
|
|
||||||
# TODO don't allow multi outputs now
|
# TODO don't allow multi outputs now
|
||||||
if len(outputs) > 1:
|
if len(outputs) > 1:
|
||||||
flow_block = True
|
flow_block = True
|
||||||
|
@ -200,7 +200,7 @@ class FlowTracer(object):
|
||||||
main_flow_var = i
|
main_flow_var = i
|
||||||
# if mix flow is a broadcast in chunk dim,
|
# if mix flow is a broadcast in chunk dim,
|
||||||
# TODO: need to move that flow out of the chunk
|
# TODO: need to move that flow out of the chunk
|
||||||
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
|
mix_flow_node_dim = index_tracer.get_node_chunk_dim(
|
||||||
self.node_list[end_idx], end_dim, node
|
self.node_list[end_idx], end_dim, node
|
||||||
)
|
)
|
||||||
if mix_flow_node_dim is None:
|
if mix_flow_node_dim is None:
|
||||||
|
@ -223,7 +223,7 @@ class FlowTracer(object):
|
||||||
if flow_block:
|
if flow_block:
|
||||||
flow_block = True
|
flow_block = True
|
||||||
return flow_block, chunk_info
|
return flow_block, chunk_info
|
||||||
|
|
||||||
inputs_dim = []
|
inputs_dim = []
|
||||||
remove_inputs = []
|
remove_inputs = []
|
||||||
for input_node in chunk_info["inputs"]:
|
for input_node in chunk_info["inputs"]:
|
||||||
|
@ -234,7 +234,7 @@ class FlowTracer(object):
|
||||||
user_idx = _find_idx_by_name(user.name, self.node_list)
|
user_idx = _find_idx_by_name(user.name, self.node_list)
|
||||||
dim = None
|
dim = None
|
||||||
if start_dim <= user_idx < end_idx:
|
if start_dim <= user_idx < end_idx:
|
||||||
dim = index_tracer._get_node_chunk_dim(
|
dim = index_tracer.get_node_chunk_dim(
|
||||||
self.node_list[end_idx], end_dim, input_node
|
self.node_list[end_idx], end_dim, input_node
|
||||||
)
|
)
|
||||||
elif user_idx == end_idx:
|
elif user_idx == end_idx:
|
||||||
|
@ -300,10 +300,10 @@ class IndexTracer(object):
|
||||||
self.idx_trace_list[idx]["compute"].pop(dim_idx)
|
self.idx_trace_list[idx]["compute"].pop(dim_idx)
|
||||||
self.idx_trace_list[idx]["source"].pop(dim_idx)
|
self.idx_trace_list[idx]["source"].pop(dim_idx)
|
||||||
|
|
||||||
def _add_dim(self, idx, dim_idx):
|
def _add_dim(self, node_idx, dim_idx):
|
||||||
self.idx_trace_list[idx]["idx"].insert(dim_idx, self._add_index())
|
self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index())
|
||||||
self.idx_trace_list[idx]["compute"].insert(dim_idx, [])
|
self.idx_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
||||||
self.idx_trace_list[idx]["source"].insert(dim_idx, {})
|
self.idx_trace_list[node_idx]["source"].insert(dim_idx, {})
|
||||||
|
|
||||||
def _transform_index(self, node, node_dim):
|
def _transform_index(self, node, node_dim):
|
||||||
node_idx = self._find_idx_trace_from_node(node)
|
node_idx = self._find_idx_trace_from_node(node)
|
||||||
|
@ -659,9 +659,7 @@ class IndexTracer(object):
|
||||||
"""
|
"""
|
||||||
self._del_dim(node_idx, -1)
|
self._del_dim(node_idx, -1)
|
||||||
self._assign_index_as_input(node, node_idx)
|
self._assign_index_as_input(node, node_idx)
|
||||||
self.idx_trace_list[node_idx]["idx"].insert(node.args[1], self._add_index())
|
self._add_dim(node_idx, node.args[1])
|
||||||
self.idx_trace_list[node_idx]["compute"].insert(node.args[1], [])
|
|
||||||
self.idx_trace_list[node_idx]["source"].insert(node.args[1], [])
|
|
||||||
|
|
||||||
def _assign_dropout_index(self, node, node_idx):
|
def _assign_dropout_index(self, node, node_idx):
|
||||||
"""
|
"""
|
||||||
|
@ -879,7 +877,7 @@ class IndexTracer(object):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||||
node_from_source = self._find_source_trace_from_node(node_from)
|
node_from_source = self._find_source_trace_from_node(node_from)
|
||||||
dim_source = node_from_source[node_from_dim]
|
dim_source = node_from_source[node_from_dim]
|
||||||
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
|
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
|
||||||
|
@ -888,6 +886,44 @@ class IndexTracer(object):
|
||||||
return v
|
return v
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||||
|
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
|
||||||
|
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||||
|
node_trace_source = self._find_source_trace_from_node(node)
|
||||||
|
for node_dim in range(len(_get_node_shape(node))):
|
||||||
|
if (
|
||||||
|
input_node_idx in node_trace_source[node_dim]
|
||||||
|
and node_trace_source[node_dim][input_node_idx] == input_dim
|
||||||
|
):
|
||||||
|
return {node_idx: node_dim}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def check_index_duplicate(self, chunk_infos):
|
||||||
|
input_dim_after_node = {}
|
||||||
|
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||||
|
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||||
|
input_dim_after_node.update(
|
||||||
|
self._find_inherit_dim(input_node, v, self.nodes_list[k])
|
||||||
|
)
|
||||||
|
|
||||||
|
for node in self.nodes_list[
|
||||||
|
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||||
|
]:
|
||||||
|
if _is_non_compute_node_except_placeholder(node):
|
||||||
|
continue
|
||||||
|
count = 0
|
||||||
|
node_trace_source = self._find_source_trace_from_node(node)
|
||||||
|
for node_dim in range(len(_get_node_shape(node))):
|
||||||
|
dim_source = node_trace_source[node_dim]
|
||||||
|
for k, v in dim_source.items():
|
||||||
|
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
|
||||||
|
if k in input_dim_after_node and input_dim_after_node[k] == v:
|
||||||
|
count += 1
|
||||||
|
break
|
||||||
|
if count > 1:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class MemoryEstimator(object):
|
class MemoryEstimator(object):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -1160,7 +1196,7 @@ class ChunkRegionSearch(object):
|
||||||
min_len = len(n)
|
min_len = len(n)
|
||||||
return min_len
|
return min_len
|
||||||
|
|
||||||
def _search_max_chunk_region(self, active_node, peak_node):
|
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
|
||||||
free_vars = self._get_free_var()
|
free_vars = self._get_free_var()
|
||||||
min_var = self._get_min_free_var(active_node, free_vars)
|
min_var = self._get_min_free_var(active_node, free_vars)
|
||||||
|
|
||||||
|
@ -1180,6 +1216,21 @@ class ChunkRegionSearch(object):
|
||||||
break
|
break
|
||||||
if i in free_vars or i == 0:
|
if i in free_vars or i == 0:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
for i in chunk_regions:
|
||||||
|
region = i["region"]
|
||||||
|
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||||
|
return None
|
||||||
|
elif (
|
||||||
|
region[0] <= chunk_region_start <= region[1]
|
||||||
|
and chunk_region_end > region[1]
|
||||||
|
):
|
||||||
|
chunk_region_start = region[1] + 1
|
||||||
|
elif (
|
||||||
|
region[0] <= chunk_region_end <= region[1]
|
||||||
|
and chunk_region_start < region[0]
|
||||||
|
):
|
||||||
|
chunk_region_end = region[0] - 1
|
||||||
return chunk_region_start, chunk_region_end
|
return chunk_region_start, chunk_region_end
|
||||||
|
|
||||||
def _is_not_compute(self, trace, chunk_range, dim_idx):
|
def _is_not_compute(self, trace, chunk_range, dim_idx):
|
||||||
|
@ -1192,24 +1243,6 @@ class ChunkRegionSearch(object):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_duplicate_map(self, chunk_infos):
|
|
||||||
dim_map = [(i["inputs_dim"], i["outputs_dim"]) for i in chunk_infos]
|
|
||||||
remove_list = []
|
|
||||||
for idx1, (input_dim1, output_dim1) in enumerate(dim_map):
|
|
||||||
for idx2, (input_dim2, output_dim2) in enumerate(dim_map):
|
|
||||||
if idx1 == idx2:
|
|
||||||
continue
|
|
||||||
# it means an index create 2 copy of itself
|
|
||||||
# eg. a = torch.matmul(x, x.transpose(-1, -2))
|
|
||||||
# TODO: currently remove it, deal with this in future
|
|
||||||
if input_dim1 == input_dim2 and output_dim1 != output_dim2:
|
|
||||||
remove_list.append(chunk_infos[idx1])
|
|
||||||
remove_list.append(chunk_infos[idx2])
|
|
||||||
for i in remove_list:
|
|
||||||
if i in chunk_infos:
|
|
||||||
chunk_infos.remove(i)
|
|
||||||
return chunk_infos
|
|
||||||
|
|
||||||
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
||||||
start_traces = input_trace[start_idx]
|
start_traces = input_trace[start_idx]
|
||||||
end_trace = output_trace[end_idx]
|
end_trace = output_trace[end_idx]
|
||||||
|
@ -1246,8 +1279,10 @@ class ChunkRegionSearch(object):
|
||||||
)
|
)
|
||||||
if flow_block:
|
if flow_block:
|
||||||
continue
|
continue
|
||||||
|
# check index copmute
|
||||||
|
if not self.index_tracer.check_index_duplicate(chunk_info):
|
||||||
|
continue
|
||||||
chunk_infos.append(chunk_info)
|
chunk_infos.append(chunk_info)
|
||||||
chunk_infos = self._check_duplicate_map(chunk_infos)
|
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
|
||||||
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
|
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
|
||||||
|
@ -1288,9 +1323,13 @@ class ChunkRegionSearch(object):
|
||||||
max_region_range = i["region"][1] - i["region"][0]
|
max_region_range = i["region"][1] - i["region"][0]
|
||||||
return best_regions
|
return best_regions
|
||||||
|
|
||||||
def _step_search(self, mem_peak, active_node):
|
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||||
peak_node = self._find_peak_node(mem_peak)
|
peak_node = self._find_peak_node(mem_peak)
|
||||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
|
max_chunk_region = self._search_max_chunk_region(
|
||||||
|
active_node, peak_node, chunk_regions
|
||||||
|
)
|
||||||
|
if max_chunk_region == None:
|
||||||
|
return None
|
||||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||||
max_chunk_region, peak_node
|
max_chunk_region, peak_node
|
||||||
)
|
)
|
||||||
|
@ -1313,7 +1352,7 @@ class ChunkRegionSearch(object):
|
||||||
mem_peak = init_mem_peak
|
mem_peak = init_mem_peak
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
chunk_region = self._step_search(mem_peak, active_node)
|
chunk_region = self._step_search(mem_peak, active_node, chunk_regions)
|
||||||
if chunk_region is None:
|
if chunk_region is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -46,8 +46,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||||
non_fx_out = model(node, pair)
|
non_fx_out = model(node, pair)
|
||||||
fx_out = gm(node, pair)
|
fx_out = gm(node, pair)
|
||||||
|
|
||||||
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output"
|
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0]))
|
||||||
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output"
|
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1]))
|
||||||
|
|
||||||
# test barckward
|
# test barckward
|
||||||
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
|
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
|
||||||
|
|
Loading…
Reference in New Issue