add part of index tracer

pull/2364/head
oahzxl 2022-11-14 16:02:47 +08:00
parent d7634af5c0
commit 1607d04e81
1 changed files with 119 additions and 0 deletions

View File

@ -19,6 +19,123 @@ else:
__all__ = ['python_code_with_activation_checkpoint']
class NodeIndexTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
self.nodes_list = list(gm.graph.nodes)
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
self.idx_trace_equal = []
self.idx_count = 1
def add_index(self):
self.idx_count += 1
return self.idx_count - 1
def inherit_computation(self, node_from, node_to):
_, compute_from = self.find_trace_from_node(node_from)
idx_to, compute_to = self.find_trace_from_node(node_to)
for i in compute_from:
if i in idx_to:
compute_to.append(i)
def mark_idx_equal(self, idx1, idx2):
self.idx_trace_equal.append((idx1, idx2))
def mark_computation(self, node, idx, dim):
input_node_idx_trace = self.find_idx_trace_from_node(node)
if isinstance(dim, int):
dim = [dim]
for d in dim:
cur_idx = input_node_idx_trace[d]
self.idx_trace_list[idx]['compute'].append(cur_idx)
def find_trace_from_node(self, node):
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_dict = self.idx_trace_list[node_idx]
return node_dict['idx'], node_dict['compute']
def find_idx_trace_from_node(self, node):
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_idx_trace = self.idx_trace_list[node_idx]['idx']
return node_idx_trace
def assign_index_as_input(self, node, node_idx):
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
def assign_all_index(self, node, node_idx):
shape = node.meta['tensor_meta'].shape
new_trace = []
for _ in shape:
new_trace.append(self.add_index())
self.idx_trace_list[node_idx]['idx'] = new_trace
def assign_transpose_index(self, node, node_idx):
tranpose_dim = node.args[1:]
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
new_idx_trace = copy.deepcopy(input_node_idx_trace)
new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]]
new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
def assign_linear_index(self, node, node_idx):
input_node, weight, bias = node.args
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
weight_idx_trace = self.find_idx_trace_from_node(weight)
new_idx_trace = copy.deepcopy(input_node_idx_trace)
new_idx_trace[-1] = weight_idx_trace[1]
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self.inherit_computation(input_node, node)
self.mark_computation(node, node_idx, [-1])
self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
if bias:
bias_idx_trace = self.find_idx_trace_from_node(bias)
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
def assign_layernorm_index(self, node, idx):
self.assign_index_as_input(node, idx)
self.mark_computation(node, idx, [-1, -2])
def trace_node_idx(self):
for idx, node in enumerate(self.nodes_list):
if node.op == 'placeholder':
self.assign_all_index(node, idx)
elif node.op == 'call_method':
if 'transpose' in node.name:
self.assign_transpose_index(node, idx)
elif 'view' in node.name:
pass
elif 'permute' in node.name:
pass
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == 'call_function':
if 'linear' in node.name:
self.assign_linear_index(node, idx)
elif 'getattr' in node.name:
continue # get attr like shape
elif 'getitem' in node.name:
continue # get item in list
else:
raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == 'call_module':
if 'layernorm' in node.name:
self.assign_layernorm_index(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
elif node.op == 'get_attr':
self.assign_all_index(node, idx) # get param
else:
raise NotImplementedError(node.op, "op not implemented yet!")
def _get_meta_node_size(x):
x = x.meta['tensor_meta']
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
@ -557,6 +674,8 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
node_list = list(nodes)
_estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
_estimate_inference_mem(meta_graph)
node_index_tracer = NodeIndexTracer(meta_graph)
node_index_tracer.trace_node_idx()
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(chunk_regions):