mirror of https://github.com/hpcaitech/ColossalAI
support output
parent
cda3e8572a
commit
de65e6c3e8
|
@ -56,6 +56,14 @@ def _is_non_compute_node_except_placeholder(node):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_non_compute_node_except_placeholder_output(node):
    """Return True when ``node`` is an attribute/item access, not a compute node.

    A node matches when either of the following holds:

    * ``"get_attr"`` is contained in ``node.op``; or
    * ``"getitem"`` or ``"getattr"`` is contained in ``node.name``.

    Both checks use the ``in`` operator, so with string attributes they are
    substring tests (e.g. a name such as ``"getitem_3"`` matches).

    Unlike a check that also lists ``"placeholder"`` or ``"output"``, this
    predicate does not match on those op strings; such a node only matches
    if its *name* happens to contain ``"getitem"`` or ``"getattr"``.

    Args:
        node: Object exposing ``op`` and ``name`` attributes.
            NOTE(review): presumably a ``torch.fx.Node`` -- confirm at callers.

    Returns:
        bool: ``True`` if the node matches one of the patterns above,
        ``False`` otherwise.
    """
    # ``in`` and ``any`` both yield real bools, so the result can be
    # returned directly instead of branching to ``return True/False``.
    return "get_attr" in node.op or any(
        marker in node.name for marker in ("getitem", "getattr")
    )
|
||||||
|
|
||||||
|
|
||||||
class FlowTracer(object):
|
class FlowTracer(object):
|
||||||
def __init__(self, gm) -> None:
|
def __init__(self, gm) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
|
@ -1083,13 +1091,14 @@ class MemoryEstimator(object):
|
||||||
i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes]
|
i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes]
|
||||||
)
|
)
|
||||||
chunk_within = False
|
chunk_within = False
|
||||||
chunk_region_idx = 0
|
chunk_region_idx = None
|
||||||
chunk_ratio = 1 # use it to estimate chunk mem
|
chunk_ratio = 1 # use it to estimate chunk mem
|
||||||
|
|
||||||
for idx, node in enumerate(node_list):
|
for idx, node in enumerate(node_list):
|
||||||
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
||||||
if use_chunk and idx in start_nodes:
|
if use_chunk and idx in start_nodes:
|
||||||
chunk_within = True
|
chunk_within = True
|
||||||
|
chunk_region_idx = start_nodes.index(idx)
|
||||||
chunk_ratio = self._get_chunk_ratio(
|
chunk_ratio = self._get_chunk_ratio(
|
||||||
node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx]
|
node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx]
|
||||||
)
|
)
|
||||||
|
@ -1149,7 +1158,7 @@ class MemoryEstimator(object):
|
||||||
)
|
)
|
||||||
chunk_within = False
|
chunk_within = False
|
||||||
chunk_ratio = 1
|
chunk_ratio = 1
|
||||||
chunk_region_idx += 1
|
chunk_region_idx = None
|
||||||
|
|
||||||
act_memory_after_node_log.append(act_memory)
|
act_memory_after_node_log.append(act_memory)
|
||||||
active_node_list_log.append(copy.deepcopy(active_node_list))
|
active_node_list_log.append(copy.deepcopy(active_node_list))
|
||||||
|
@ -1467,7 +1476,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||||
if (
|
if (
|
||||||
output_node not in nodes
|
output_node not in nodes
|
||||||
and node not in output_nodes
|
and node not in output_nodes
|
||||||
and not _is_non_compute_node_except_placeholder(output_node)
|
and not _is_non_compute_node_except_placeholder_output(output_node)
|
||||||
):
|
):
|
||||||
output_nodes.append(node)
|
output_nodes.append(node)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue