mirror of https://github.com/hpcaitech/ColossalAI
add doc
parent
1be0ac3cbf
commit
7d4abaa525
|
@ -20,11 +20,22 @@ from .search_chunk import SearchChunk
|
||||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
|
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
|
||||||
|
|
||||||
|
|
||||||
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
|
||||||
|
"""
|
||||||
|
Generate chunk slice string, eg. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_dim (int)
|
||||||
|
chunk_indice_name (str): chunk indice name
|
||||||
|
shape (List): node shape
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
new_shape (str): return slice
|
||||||
|
"""
|
||||||
new_shape = "["
|
new_shape = "["
|
||||||
for idx, i in enumerate(shape):
|
for idx, _ in enumerate(shape):
|
||||||
if idx == chunk_dim:
|
if idx == chunk_dim:
|
||||||
new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
|
new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
|
||||||
else:
|
else:
|
||||||
new_shape += ":"
|
new_shape += ":"
|
||||||
new_shape += ", "
|
new_shape += ", "
|
||||||
|
@ -32,7 +43,26 @@ def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
||||||
return new_shape
|
return new_shape
|
||||||
|
|
||||||
|
|
||||||
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
|
def _gen_loop_start(
|
||||||
|
chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate chunk loop start
|
||||||
|
|
||||||
|
eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
|
||||||
|
chunk_size = 32
|
||||||
|
for chunk_idx in range(0, 100, 32):
|
||||||
|
......
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_input (List[Node]): chunk input node
|
||||||
|
chunk_output (Node): chunk output node
|
||||||
|
chunk_ouput_dim (int): chunk output node chunk dim
|
||||||
|
chunk_size (int): chunk size. Defaults to 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
context (str): generated str
|
||||||
|
"""
|
||||||
input_node = chunk_input[0]
|
input_node = chunk_input[0]
|
||||||
out_shape = get_node_shape(chunk_output)
|
out_shape = get_node_shape(chunk_output)
|
||||||
out_str = str(list(out_shape))
|
out_str = str(list(out_shape))
|
||||||
|
@ -45,8 +75,28 @@ def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
|
||||||
|
|
||||||
|
|
||||||
def _gen_loop_end(
|
def _gen_loop_end(
|
||||||
chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
|
chunk_inputs: List[Node],
|
||||||
):
|
chunk_non_compute_inputs: List[Node],
|
||||||
|
chunk_outputs: Node,
|
||||||
|
chunk_outputs_dim: int,
|
||||||
|
node_list: List[Node],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate chunk loop end
|
||||||
|
|
||||||
|
eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
|
||||||
|
output_node = chunk_result; xx = None; xx = None
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_inputs (List[Node]): chunk input node
|
||||||
|
chunk_non_compute_inputs (List[Node]): input node without chunk
|
||||||
|
chunk_outputs (Node): chunk output node
|
||||||
|
chunk_outputs_dim (int): chunk output node chunk dim
|
||||||
|
node_list (List)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
context (str): generated str
|
||||||
|
"""
|
||||||
chunk_outputs_name = chunk_outputs.name
|
chunk_outputs_name = chunk_outputs.name
|
||||||
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
|
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
|
||||||
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
||||||
|
@ -76,7 +126,10 @@ def _gen_loop_end(
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _replace_name(context, name_from, name_to):
|
def _replace_name(context: str, name_from: str, name_to: str) -> str:
|
||||||
|
"""
|
||||||
|
replace node name
|
||||||
|
"""
|
||||||
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
|
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
|
||||||
for p in patterns:
|
for p in patterns:
|
||||||
source = p[0] + name_from + p[1]
|
source = p[0] + name_from + p[1]
|
||||||
|
@ -86,7 +139,10 @@ def _replace_name(context, name_from, name_to):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _replace_reshape_size(context, node_name, reshape_size_dict):
|
def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
|
||||||
|
"""
|
||||||
|
replace reshape size, some may have changed due to chunk
|
||||||
|
"""
|
||||||
if node_name not in reshape_size_dict:
|
if node_name not in reshape_size_dict:
|
||||||
return context
|
return context
|
||||||
for size_name, size_value in reshape_size_dict[node_name].items():
|
for size_name, size_value in reshape_size_dict[node_name].items():
|
||||||
|
@ -94,7 +150,17 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_idx, node, body):
|
def _replace_ones_like(
|
||||||
|
search_chunk: SearchChunk,
|
||||||
|
chunk_infos: List[Dict],
|
||||||
|
region_idx: int,
|
||||||
|
node_idx: int,
|
||||||
|
node: Node,
|
||||||
|
body: List[str],
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
add chunk slice for new tensor op such as ones like
|
||||||
|
"""
|
||||||
if "ones_like" in node.name:
|
if "ones_like" in node.name:
|
||||||
meta_node = search_chunk.trace_indice.node_list[node_idx]
|
meta_node = search_chunk.trace_indice.node_list[node_idx]
|
||||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||||
|
@ -114,7 +180,16 @@ def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
|
||||||
def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body):
|
def _replace_input_node(
|
||||||
|
chunk_inputs: List[Node],
|
||||||
|
region_idx: int,
|
||||||
|
chunk_inputs_dim: Dict,
|
||||||
|
node_idx: int,
|
||||||
|
body: List[str],
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
add chunk slice for input nodes
|
||||||
|
"""
|
||||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||||
if idx == node_idx:
|
if idx == node_idx:
|
||||||
|
@ -193,7 +268,7 @@ def emit_code_with_chunk(
|
||||||
if within_chunk_region:
|
if within_chunk_region:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
# replace input var with chunk var
|
# replace input var with chunk var
|
||||||
body = _replace_input_var(
|
body = _replace_input_node(
|
||||||
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
|
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
|
||||||
)
|
)
|
||||||
# ones like
|
# ones like
|
||||||
|
|
|
@ -15,6 +15,10 @@ from .utils import (
|
||||||
|
|
||||||
|
|
||||||
class EstimateMemory(object):
|
class EstimateMemory(object):
|
||||||
|
"""
|
||||||
|
Estimate memory with chunk
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -31,8 +35,6 @@ class EstimateMemory(object):
|
||||||
}
|
}
|
||||||
out_size = activation_size(fwd_out)
|
out_size = activation_size(fwd_out)
|
||||||
out_node = [n.name] if out_size > 0 else []
|
out_node = [n.name] if out_size > 0 else []
|
||||||
# if any(i in n.name for i in ['transpose', 'permute', 'view']):
|
|
||||||
# out_size = 0
|
|
||||||
return out_size, out_node
|
return out_size, out_node
|
||||||
|
|
||||||
def _get_output_node_size(self, n):
|
def _get_output_node_size(self, n):
|
||||||
|
@ -184,10 +186,24 @@ class EstimateMemory(object):
|
||||||
|
|
||||||
def estimate_chunk_inference_mem(
|
def estimate_chunk_inference_mem(
|
||||||
self,
|
self,
|
||||||
node_list,
|
node_list: List,
|
||||||
chunk_infos=None,
|
chunk_infos=None,
|
||||||
print_mem=False,
|
print_mem=False,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Estimate inference memory with chunk
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_list (List): _description_
|
||||||
|
chunk_infos (Dict): Chunk information. Defaults to None.
|
||||||
|
print_mem (bool): Wether to print peak memory of every node. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
act_memory_peak_log (List): peak memory of every node
|
||||||
|
act_memory_after_node_log (List): memory after excuting every node
|
||||||
|
active_node_list_log (List): active nodes of every node. active nodes refer to
|
||||||
|
nodes generated but not deleted.
|
||||||
|
"""
|
||||||
act_memory = 0.0
|
act_memory = 0.0
|
||||||
act_memory_peak_log = []
|
act_memory_peak_log = []
|
||||||
act_memory_after_node_log = []
|
act_memory_after_node_log = []
|
||||||
|
|
|
@ -3,6 +3,10 @@ from .utils import find_idx_by_name
|
||||||
|
|
||||||
|
|
||||||
class ReorderGraph(object):
|
class ReorderGraph(object):
|
||||||
|
"""
|
||||||
|
Reorder node list and indice trace list
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||||
self.trace_indice = trace_indice
|
self.trace_indice = trace_indice
|
||||||
self.all_reorder_map = {
|
self.all_reorder_map = {
|
||||||
|
@ -60,7 +64,9 @@ class ReorderGraph(object):
|
||||||
|
|
||||||
def _reorder_idx_trace(self, reorder_map):
|
def _reorder_idx_trace(self, reorder_map):
|
||||||
# reorder list
|
# reorder list
|
||||||
new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
|
new_idx_trace_list = [
|
||||||
|
None for _ in range(len(self.trace_indice.indice_trace_list))
|
||||||
|
]
|
||||||
for old_idx, new_idx in reorder_map.items():
|
for old_idx, new_idx in reorder_map.items():
|
||||||
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
|
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
|
||||||
self.trace_indice.indice_trace_list = new_idx_trace_list
|
self.trace_indice.indice_trace_list = new_idx_trace_list
|
||||||
|
|
Loading…
Reference in New Issue