mirror of https://github.com/hpcaitech/ColossalAI
support index duplicate check and update search loop
parent
1e0fd11bc1
commit
cda3e8572a
109
chunk_codegen.py
109
chunk_codegen.py
|
@ -180,7 +180,7 @@ class FlowTracer(object):
|
|||
"args": {},
|
||||
}
|
||||
flow_block = False
|
||||
|
||||
|
||||
# TODO don't allow multi outputs now
|
||||
if len(outputs) > 1:
|
||||
flow_block = True
|
||||
|
@ -200,7 +200,7 @@ class FlowTracer(object):
|
|||
main_flow_var = i
|
||||
# if mix flow is a broadcast in chunk dim,
|
||||
# TODO: need to move that flow out of the chunk
|
||||
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
|
||||
mix_flow_node_dim = index_tracer.get_node_chunk_dim(
|
||||
self.node_list[end_idx], end_dim, node
|
||||
)
|
||||
if mix_flow_node_dim is None:
|
||||
|
@ -223,7 +223,7 @@ class FlowTracer(object):
|
|||
if flow_block:
|
||||
flow_block = True
|
||||
return flow_block, chunk_info
|
||||
|
||||
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
for input_node in chunk_info["inputs"]:
|
||||
|
@ -234,7 +234,7 @@ class FlowTracer(object):
|
|||
user_idx = _find_idx_by_name(user.name, self.node_list)
|
||||
dim = None
|
||||
if start_dim <= user_idx < end_idx:
|
||||
dim = index_tracer._get_node_chunk_dim(
|
||||
dim = index_tracer.get_node_chunk_dim(
|
||||
self.node_list[end_idx], end_dim, input_node
|
||||
)
|
||||
elif user_idx == end_idx:
|
||||
|
@ -300,10 +300,10 @@ class IndexTracer(object):
|
|||
self.idx_trace_list[idx]["compute"].pop(dim_idx)
|
||||
self.idx_trace_list[idx]["source"].pop(dim_idx)
|
||||
|
||||
def _add_dim(self, idx, dim_idx):
    """Insert a new dimension at position ``dim_idx`` of trace entry ``idx``.

    The entry ``self.idx_trace_list[idx]`` gets, at ``dim_idx``:
    a fresh value from ``self._add_index()`` in its ``"idx"`` list,
    an empty list in ``"compute"`` and an empty dict in ``"source"``.
    """
    entry = self.idx_trace_list[idx]
    entry["idx"].insert(dim_idx, self._add_index())
    entry["compute"].insert(dim_idx, [])
    entry["source"].insert(dim_idx, {})
|
||||
def _add_dim(self, node_idx, dim_idx):
    """Open up a new dimension slot in the trace of node ``node_idx``.

    At position ``dim_idx`` the trace entry receives a newly allocated
    index (from ``self._add_index()``), an empty compute list and an
    empty source dict.
    """
    trace = self.idx_trace_list[node_idx]
    # Each field is paired with the factory producing its initial value;
    # the factory is only invoked once the target list has been resolved.
    fields = (("idx", self._add_index), ("compute", list), ("source", dict))
    for field, make_value in fields:
        trace[field].insert(dim_idx, make_value())
|
||||
|
||||
def _transform_index(self, node, node_dim):
|
||||
node_idx = self._find_idx_trace_from_node(node)
|
||||
|
@ -659,9 +659,7 @@ class IndexTracer(object):
|
|||
"""
|
||||
self._del_dim(node_idx, -1)
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
self.idx_trace_list[node_idx]["idx"].insert(node.args[1], self._add_index())
|
||||
self.idx_trace_list[node_idx]["compute"].insert(node.args[1], [])
|
||||
self.idx_trace_list[node_idx]["source"].insert(node.args[1], [])
|
||||
self._add_dim(node_idx, node.args[1])
|
||||
|
||||
def _assign_dropout_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -879,7 +877,7 @@ class IndexTracer(object):
|
|||
return False
|
||||
return True
|
||||
|
||||
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self._find_source_trace_from_node(node_from)
|
||||
dim_source = node_from_source[node_from_dim]
|
||||
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
|
||||
|
@ -888,6 +886,44 @@ class IndexTracer(object):
|
|||
return v
|
||||
return None
|
||||
|
||||
def _find_inherit_dim(self, input_node, input_dim, node):
    """Find which dimension of ``node`` is traced back to ``input_dim`` of ``input_node``.

    Args:
        input_node: node whose position in ``self.nodes_list`` is looked up by name.
        input_dim: dimension of ``input_node`` to search for.
        node: node whose per-dimension source trace is inspected.

    Returns:
        ``{node_index: node_dim}`` for the first dimension of ``node`` whose
        source trace maps ``input_node``'s index to ``input_dim``;
        an empty dict when no dimension matches.
    """
    all_nodes = self.nodes_list
    src_idx = _find_idx_by_name(input_node.name, all_nodes)
    dst_idx = _find_idx_by_name(node.name, all_nodes)
    sources_per_dim = self._find_source_trace_from_node(node)

    for dim in range(len(_get_node_shape(node))):
        dim_sources = sources_per_dim[dim]
        if src_idx not in dim_sources:
            continue
        if dim_sources[src_idx] == input_dim:
            return {dst_idx: dim}
    return {}
|
||||
|
||||
def check_index_duplicate(self, chunk_infos):
    """Reject a chunk candidate in which one node carries the chunked index twice.

    First, for every chunk input, the dimension that inherits the input's
    chunk dim is collected via ``_find_inherit_dim`` (keyed by node index).
    Then every node inside ``chunk_infos["region"]`` (both ends inclusive)
    is scanned: a dimension counts as a hit when its source trace contains
    an in-region entry that equals one of the collected (index, dim) pairs.

    NOTE(review): indentation was lost in the view this was taken from;
    the nesting below is reconstructed from syntax — confirm against the
    real file.

    Args:
        chunk_infos: dict read for the keys ``"inputs"``, ``"inputs_dim"``
            and ``"region"``.

    Returns:
        ``False`` as soon as a node has more than one hit dimension,
        ``True`` otherwise.
    """
    inherited = {}
    for position, input_node in enumerate(chunk_infos["inputs"]):
        dims_by_user = chunk_infos["inputs_dim"][position]
        for user_idx, user_dim in dims_by_user.items():
            found = self._find_inherit_dim(
                input_node, user_dim, self.nodes_list[user_idx]
            )
            inherited.update(found)

    first, last = chunk_infos["region"][0], chunk_infos["region"][1]
    for node in self.nodes_list[first : last + 1]:
        if _is_non_compute_node_except_placeholder(node):
            continue
        sources_per_dim = self._find_source_trace_from_node(node)
        hits = 0
        for dim in range(len(_get_node_shape(node))):
            # A dimension contributes at most one hit, whatever number of
            # its sources match.
            dim_matches = any(
                first <= src_idx <= last
                and src_idx in inherited
                and inherited[src_idx] == src_dim
                for src_idx, src_dim in sources_per_dim[dim].items()
            )
            if dim_matches:
                hits += 1
        if hits > 1:
            return False
    return True
|
||||
|
||||
|
||||
class MemoryEstimator(object):
|
||||
def __init__(self) -> None:
|
||||
|
@ -1160,7 +1196,7 @@ class ChunkRegionSearch(object):
|
|||
min_len = len(n)
|
||||
return min_len
|
||||
|
||||
def _search_max_chunk_region(self, active_node, peak_node):
|
||||
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
|
||||
free_vars = self._get_free_var()
|
||||
min_var = self._get_min_free_var(active_node, free_vars)
|
||||
|
||||
|
@ -1180,6 +1216,21 @@ class ChunkRegionSearch(object):
|
|||
break
|
||||
if i in free_vars or i == 0:
|
||||
raise RuntimeError()
|
||||
|
||||
for i in chunk_regions:
|
||||
region = i["region"]
|
||||
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||
return None
|
||||
elif (
|
||||
region[0] <= chunk_region_start <= region[1]
|
||||
and chunk_region_end > region[1]
|
||||
):
|
||||
chunk_region_start = region[1] + 1
|
||||
elif (
|
||||
region[0] <= chunk_region_end <= region[1]
|
||||
and chunk_region_start < region[0]
|
||||
):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
def _is_not_compute(self, trace, chunk_range, dim_idx):
|
||||
|
@ -1192,24 +1243,6 @@ class ChunkRegionSearch(object):
|
|||
return True
|
||||
return False
|
||||
|
||||
def _check_duplicate_map(self, chunk_infos):
    """Drop chunk infos whose input dims coincide but whose output dims differ.

    Such a pair means one index produced two copies of itself,
    e.g. ``a = torch.matmul(x, x.transpose(-1, -2))``.
    TODO: currently these candidates are simply removed; handle them properly later.

    Args:
        chunk_infos: list of dicts read for ``"inputs_dim"`` and ``"outputs_dim"``.

    Returns:
        The same list object, mutated in place, with every conflicting
        entry removed.
    """
    dim_pairs = [(info["inputs_dim"], info["outputs_dim"]) for info in chunk_infos]

    # Every ordered pair of distinct positions is examined; both members of a
    # conflicting pair are queued for removal.
    doomed = [
        chunk_infos[pos]
        for pos_a, (in_a, out_a) in enumerate(dim_pairs)
        for pos_b, (in_b, out_b) in enumerate(dim_pairs)
        if pos_a != pos_b and in_a == in_b and out_a != out_b
        for pos in (pos_a, pos_b)
    ]

    for info in doomed:
        # The queue can hold repeats, so only remove what is still present.
        if info in chunk_infos:
            chunk_infos.remove(info)
    return chunk_infos
|
||||
|
||||
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
||||
start_traces = input_trace[start_idx]
|
||||
end_trace = output_trace[end_idx]
|
||||
|
@ -1246,8 +1279,10 @@ class ChunkRegionSearch(object):
|
|||
)
|
||||
if flow_block:
|
||||
continue
|
||||
# check index compute
|
||||
if not self.index_tracer.check_index_duplicate(chunk_info):
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
chunk_infos = self._check_duplicate_map(chunk_infos)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
|
||||
|
@ -1288,9 +1323,13 @@ class ChunkRegionSearch(object):
|
|||
max_region_range = i["region"][1] - i["region"][0]
|
||||
return best_regions
|
||||
|
||||
def _step_search(self, mem_peak, active_node):
|
||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
|
||||
max_chunk_region = self._search_max_chunk_region(
|
||||
active_node, peak_node, chunk_regions
|
||||
)
|
||||
if max_chunk_region == None:
|
||||
return None
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||
max_chunk_region, peak_node
|
||||
)
|
||||
|
@ -1313,7 +1352,7 @@ class ChunkRegionSearch(object):
|
|||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
chunk_region = self._step_search(mem_peak, active_node)
|
||||
chunk_region = self._step_search(mem_peak, active_node, chunk_regions)
|
||||
if chunk_region is None:
|
||||
break
|
||||
|
||||
|
|
|
@ -46,8 +46,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
|||
non_fx_out = model(node, pair)
|
||||
fx_out = gm(node, pair)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output"
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output"
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
# test backward
|
||||
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
|
||||
|
|
Loading…
Reference in New Issue