polish code

pull/2364/head
oahzxl 2022-11-15 10:46:51 +08:00
parent 7e2bd1e428
commit fad3b6d1a6
1 changed files with 237 additions and 237 deletions

View File

@ -10,6 +10,13 @@ CODEGEN_AVAILABLE = True
__all__ = ['ChunkCodeGen']
def _delete_free_var_from_last_use(user_to_last_uses):
for key, value in user_to_last_uses.items():
for n in value:
if n.op == 'placeholder':
user_to_last_uses[key].remove(n)
class NodeIndexTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
@ -19,7 +26,7 @@ class NodeIndexTracer(object):
self.idx_view_list = []
self.idx_count = -1
def add_index(self):
def _add_index(self):
"""
Update the count and return it. To record the idx number.
@ -29,7 +36,7 @@ class NodeIndexTracer(object):
self.idx_count += 1
return self.idx_count
def inherit_computation(self, node_from, node_to):
def _inherit_computation(self, node_from, node_to):
"""
Inherit computed dim from node_from to node_to.
If a dim in node_from is marked as computed and exists in node_to,
@ -39,13 +46,13 @@ class NodeIndexTracer(object):
node_from (node): node to be inherited
node_to (node): new node to inherit
"""
_, compute_from = self.find_trace_from_node(node_from)
idx_to, compute_to = self.find_trace_from_node(node_to)
_, compute_from = self._find_trace_from_node(node_from)
idx_to, compute_to = self._find_trace_from_node(node_to)
for i in compute_from:
if i in idx_to and i not in compute_to:
compute_to.append(i)
def mark_idx_equal(self, idx1, idx2):
def _mark_idx_equal(self, idx1, idx2):
"""
Mark 2 index to be equal.
@ -55,7 +62,7 @@ class NodeIndexTracer(object):
"""
self.idx_trace_equal.append((idx1, idx2))
def mark_computation(self, node, idx, dim):
def _mark_computation(self, node, idx, dim):
"""
Mark some dims of node as computed.
@ -64,7 +71,7 @@ class NodeIndexTracer(object):
idx (int): node index
dim (list or int): dims to be marked as computed
"""
input_node_idx_trace = self.find_idx_trace_from_node(node)
input_node_idx_trace = self._find_idx_trace_from_node(node)
if isinstance(dim, int):
dim = [dim]
for d in dim:
@ -72,7 +79,7 @@ class NodeIndexTracer(object):
if cur_idx not in self.idx_trace_list[idx]['compute']:
self.idx_trace_list[idx]['compute'].append(cur_idx)
def find_trace_from_node(self, node):
def _find_trace_from_node(self, node):
"""
Find node idx and compute trace by the node.
@ -86,7 +93,7 @@ class NodeIndexTracer(object):
node_dict = self.idx_trace_list[node_idx]
return node_dict['idx'], node_dict['compute']
def find_idx_trace_from_node(self, node):
def _find_idx_trace_from_node(self, node):
"""
Find node idx trace by the node.
@ -98,7 +105,7 @@ class NodeIndexTracer(object):
node_idx = _find_idx_by_name(node.name, self.nodes_list)
return self.idx_trace_list[node_idx]['idx']
def find_compute_trace_from_node(self, node):
def _find_compute_trace_from_node(self, node):
"""
Find node compute trace by the node.
@ -110,7 +117,7 @@ class NodeIndexTracer(object):
node_idx = _find_idx_by_name(node.name, self.nodes_list)
return self.idx_trace_list[node_idx]['compute']
def assign_index_as_input(self, node, node_idx):
def _assign_index_as_input(self, node, node_idx):
"""
Assign node's trace as its input node.
@ -124,7 +131,7 @@ class NodeIndexTracer(object):
new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
def assign_all_index(self, node, node_idx):
def _assign_all_index(self, node, node_idx):
"""
Add new index for all node's dims.
@ -135,10 +142,10 @@ class NodeIndexTracer(object):
shape = node.meta['tensor_meta'].shape
new_trace = []
for _ in shape:
new_trace.append(self.add_index())
new_trace.append(self._add_index())
self.idx_trace_list[node_idx]['idx'] = new_trace
def assign_transpose_index(self, node, node_idx):
def _assign_transpose_index(self, node, node_idx):
"""
Assign index for transpose op.
1. swap input's dim according to transpose args
@ -149,16 +156,16 @@ class NodeIndexTracer(object):
node_idx (int)
"""
tranpose_dim = node.args[1:]
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
new_idx_trace = copy.deepcopy(input_node_idx_trace)
new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]]
new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self.inherit_computation(node.args[0], node)
self._inherit_computation(node.args[0], node)
def assign_permute_index(self, node, node_idx):
def _assign_permute_index(self, node, node_idx):
"""
Assign index for permute op.
1. swap input's dim according to permute args
@ -169,16 +176,16 @@ class NodeIndexTracer(object):
node_idx (int)
"""
permute_dim = node.args[1:]
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
new_idx_trace = copy.deepcopy(input_node_idx_trace)
for idx, d in enumerate(permute_dim):
new_idx_trace[idx] = input_node_idx_trace[d]
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self.inherit_computation(node.args[0], node)
self._inherit_computation(node.args[0], node)
def assign_linear_index(self, node, node_idx):
def _assign_linear_index(self, node, node_idx):
"""
Assign index for linear op.
1. copy trace from input node and change last index accroding to weight
@ -190,22 +197,22 @@ class NodeIndexTracer(object):
node_idx (int)
"""
input_node, weight, bias = node.args
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
weight_idx_trace = self.find_idx_trace_from_node(weight)
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
weight_idx_trace = self._find_idx_trace_from_node(weight)
new_idx_trace = copy.deepcopy(input_node_idx_trace)
new_idx_trace[-1] = weight_idx_trace[1]
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self.inherit_computation(input_node, node)
self.mark_computation(node, node_idx, [-1])
self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
self._inherit_computation(input_node, node)
self._mark_computation(node, node_idx, [-1])
self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
if bias:
bias_idx_trace = self.find_idx_trace_from_node(bias)
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
bias_idx_trace = self._find_idx_trace_from_node(bias)
self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
def assign_matmul_index(self, node, node_idx):
def _assign_matmul_index(self, node, node_idx):
"""
Assign index for matmul op.
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
@ -217,20 +224,20 @@ class NodeIndexTracer(object):
node_idx (int)
"""
matmul_left, matmul_right = node.args
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left)
matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right)
assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace))
new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
new_idx_trace[-1] = matmul_right_idx_trace[-1]
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self.inherit_computation(matmul_left, node)
self.inherit_computation(matmul_right, node)
self.mark_computation(node, node_idx, [-1])
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
self._inherit_computation(matmul_left, node)
self._inherit_computation(matmul_right, node)
self._mark_computation(node, node_idx, [-1])
self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
def assign_layernorm_index(self, node, idx):
def _assign_layernorm_index(self, node, idx):
"""
Assign index for layernorm op.
1. assign index as input node
@ -240,11 +247,11 @@ class NodeIndexTracer(object):
node (node)
node_idx (int)
"""
self.assign_index_as_input(node, idx)
self.inherit_computation(node.args[0], node)
self.mark_computation(node, idx, [-1, -2])
self._assign_index_as_input(node, idx)
self._inherit_computation(node.args[0], node)
self._mark_computation(node, idx, [-1, -2])
def assign_elementwise_index(self, node, idx):
def _assign_elementwise_index(self, node, idx):
"""
Assign index for element-wise op (eg. relu sigmoid add mul).
1. assign index as input node
@ -254,12 +261,12 @@ class NodeIndexTracer(object):
node (node)
node_idx (int)
"""
self.assign_index_as_input(node, idx)
self._assign_index_as_input(node, idx)
for node_in in node.args:
if type(node_in) not in (int, float):
self.inherit_computation(node_in, node)
self._inherit_computation(node_in, node)
def assign_softmax_index(self, node, idx):
def _assign_softmax_index(self, node, idx):
"""
Assign index for softmax op.
1. assign index as input node
@ -269,11 +276,11 @@ class NodeIndexTracer(object):
node (node)
node_idx (int)
"""
self.assign_index_as_input(node, idx)
self.inherit_computation(node.args[0], node)
self.mark_computation(node, idx, [node.kwargs['dim']])
self._assign_index_as_input(node, idx)
self._inherit_computation(node.args[0], node)
self._mark_computation(node, idx, [node.kwargs['dim']])
def assign_view_reshape_index(self, node, node_idx):
def _assign_view_reshape_index(self, node, node_idx):
"""
Assign index for view and reshape op.
1. get origin shape and target shape by meta info.
@ -325,22 +332,22 @@ class NodeIndexTracer(object):
raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")
# get new index
origin_trace = self.find_idx_trace_from_node(origin_node)
origin_trace = self._find_idx_trace_from_node(origin_node)
new_trace = copy.deepcopy(origin_trace)
dim_from.reverse()
for i in dim_from:
new_trace.pop(i)
for i in dim_to:
new_trace.insert(i, self.add_index())
new_trace.insert(i, self._add_index())
self.idx_trace_list[node_idx]['idx'] = new_trace
# inherit computation
self.inherit_computation(origin_node, node)
compute_log = self.find_compute_trace_from_node(origin_node)
self._inherit_computation(origin_node, node)
compute_log = self._find_compute_trace_from_node(origin_node)
for i in dim_from:
if origin_trace[i] in compute_log:
for j in dim_to:
self.mark_computation(node, node_idx, [j])
self._mark_computation(node, node_idx, [j])
break
# log view, not used now
@ -353,25 +360,25 @@ class NodeIndexTracer(object):
def trace_node_idx(self):
for idx, node in enumerate(self.nodes_list):
if node.op == 'placeholder':
self.assign_all_index(node, idx)
self._assign_all_index(node, idx)
elif node.op == 'call_method':
if 'transpose' in node.name:
self.assign_transpose_index(node, idx)
self._assign_transpose_index(node, idx)
elif 'permute' in node.name:
self.assign_permute_index(node, idx)
self._assign_permute_index(node, idx)
elif 'view' in node.name or 'reshape' in node.name:
self.assign_view_reshape_index(node, idx)
self._assign_view_reshape_index(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == 'call_function':
if 'linear' in node.name:
self.assign_linear_index(node, idx)
self._assign_linear_index(node, idx)
elif 'matmul' in node.name:
self.assign_matmul_index(node, idx)
self._assign_matmul_index(node, idx)
elif 'softmax' in node.name:
self.assign_softmax_index(node, idx)
self._assign_softmax_index(node, idx)
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
self.assign_elementwise_index(node, idx)
self._assign_elementwise_index(node, idx)
elif 'getattr' in node.name:
continue # get attr like shape
elif 'getitem' in node.name:
@ -380,39 +387,40 @@ class NodeIndexTracer(object):
raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == 'call_module':
if any(n in node.name for n in ['layernorm', 'norm']):
self.assign_layernorm_index(node, idx)
self._assign_layernorm_index(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
elif node.op == 'get_attr':
self.assign_all_index(node, idx) # get param
self._assign_all_index(node, idx) # get param
elif node.op == 'output':
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
def _get_meta_node_size(x):
class MemoryEstimator(object):
def __init__(self) -> None:
pass
def _get_meta_node_size(self, x):
x = x.meta['tensor_meta']
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
return x
def _get_output_node_size(n):
def _get_output_node_size(self, n):
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
return activation_size(fwd_out)
def _get_delete_node_size(user, user_to_last_uses):
def _get_delete_node_size(self, user, user_to_last_uses):
if user.op in ('placeholder', 'output'):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete])
return delete_size
return 0
def _get_last_usr(nodes):
def _get_last_usr(self, nodes):
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
@ -426,15 +434,7 @@ def _get_last_usr(nodes):
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return user_to_last_uses
def _delete_free_var_from_last_use(user_to_last_uses):
for key, value in user_to_last_uses.items():
for n in value:
if n.op == 'placeholder':
user_to_last_uses[key].remove(n)
def _get_contiguous_memory(node, not_contiguous_list, delete=False):
def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
mem = 0
not_contiguous_ops = ['transpose', 'permute']
@ -442,7 +442,7 @@ def _get_contiguous_memory(node, not_contiguous_list, delete=False):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
mem += _get_output_node_size(n)
mem += self._get_output_node_size(n)
elif node.op == 'call_module':
for n in node.args:
if n in not_contiguous_list:
@ -458,18 +458,17 @@ def _get_contiguous_memory(node, not_contiguous_list, delete=False):
return mem
def _estimate_inference_mem(gm: torch.fx.GraphModule):
def estimate_inference_mem(self, gm: torch.fx.GraphModule):
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
not_contiguous_list = []
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
_delete_free_var_from_last_use(user_to_last_uses)
for node in gm.graph.nodes:
# if node is placeholder, just add the size of the node
if node.op == 'placeholder':
act_memory += _get_meta_node_size(node) / (1024 ** 2)
act_memory += self._get_meta_node_size(node) / (1024 ** 2)
act_memory_peak_log.append(act_memory)
act_memory_after_node_log.append(act_memory)
# skip output
@ -478,30 +477,30 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
# node is an operation, calculate tmp, output node and delete node memory
else:
# forward memory
act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
act_memory += _get_output_node_size(node) / (1024 ** 2)
act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
act_memory += self._get_output_node_size(node) / (1024 ** 2)
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
act_memory_after_node_log.append(act_memory)
print("no chunk")
_print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
_print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
param_memory = parameter_size(gm)
return act_memory + param_memory, param_memory
def _get_chunk_ratio(node, chunk_dim, chunk_size):
def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
shape = node.meta['tensor_meta'].shape
chunk_ratio = float(chunk_size) / shape[chunk_dim]
return chunk_ratio
def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
if user.op in ('placeholder', 'output'):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
@ -509,11 +508,11 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
for n in nodes_to_delete:
node_idx = _find_idx_by_name(n.name, node_list)
if start_node <= node_idx < end_node:
delete_size += _get_output_node_size(n) * chunk_ratio
delete_size += self._get_output_node_size(n) * chunk_ratio
return delete_size
def _print_mem_log(log, nodes, title=None):
def _print_mem_log(self, log, nodes, title=None):
if title:
print(title)
for idx, (l, n) in enumerate(zip(log, nodes)):
@ -523,12 +522,12 @@ def _print_mem_log(log, nodes, title=None):
print("\n")
def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
not_contiguous_list = []
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
_delete_free_var_from_last_use(user_to_last_uses)
within_chunk = False
region_idx = 0
@ -539,12 +538,12 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if idx in start_nodes:
within_chunk = True
chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
# if node is placeholder, just add the size of the node
if node.op == 'placeholder':
act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
act_memory_peak_log.append(act_memory)
# skip output
elif node.op == 'output':
@ -553,21 +552,21 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
else:
# forward memory
# TODO: permute will create a tmp copy if not contiguous
act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
if within_chunk:
act_memory -= _get_chunk_delete_node_size(
act_memory -= self._get_chunk_delete_node_size(
node, user_to_last_uses, chunk_ratio, node_list,
start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
else:
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
if idx in end_nodes:
act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
within_chunk = False
chunk_ratio = 1
region_idx += 1
@ -575,8 +574,8 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
act_memory_after_node_log.append(act_memory)
print("chunk")
_print_mem_log(act_memory_peak_log, node_list, "peak")
_print_mem_log(act_memory_after_node_log, node_list, "after")
self._print_mem_log(act_memory_peak_log, node_list, "peak")
self._print_mem_log(act_memory_after_node_log, node_list, "after")
param_memory = parameter_size(gm)
return act_memory + param_memory, param_memory
@ -695,8 +694,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
within_chunk_region = False
node_list = list(nodes)
_estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
_estimate_inference_mem(meta_graph)
memory_estimator = MemoryEstimator()
memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
memory_estimator.estimate_inference_mem(meta_graph)
node_index_tracer = NodeIndexTracer(meta_graph)
node_index_tracer.trace_node_idx()