finish basic index tracer

pull/2364/head
oahzxl 2022-11-14 23:38:05 +08:00
parent 1607d04e81
commit c36dba07de
1 changed file with 124 additions and 9 deletions


@@ -25,6 +25,7 @@ class NodeIndexTracer(object):
         self.nodes_list = list(gm.graph.nodes)
         self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
         self.idx_trace_equal = []
+        self.idx_view_list = []
         self.idx_count = 1

     def add_index(self):
@@ -35,7 +36,7 @@ class NodeIndexTracer(object):
         _, compute_from = self.find_trace_from_node(node_from)
         idx_to, compute_to = self.find_trace_from_node(node_to)
         for i in compute_from:
-            if i in idx_to:
+            if i in idx_to and i not in compute_to:
                 compute_to.append(i)

     def mark_idx_equal(self, idx1, idx2):
@@ -47,7 +48,8 @@ class NodeIndexTracer(object):
             dim = [dim]
         for d in dim:
             cur_idx = input_node_idx_trace[d]
-            self.idx_trace_list[idx]['compute'].append(cur_idx)
+            if cur_idx not in self.idx_trace_list[idx]['compute']:
+                self.idx_trace_list[idx]['compute'].append(cur_idx)

     def find_trace_from_node(self, node):
         node_idx = _find_idx_by_name(node.name, self.nodes_list)
@@ -56,8 +58,11 @@ class NodeIndexTracer(object):
     def find_idx_trace_from_node(self, node):
         node_idx = _find_idx_by_name(node.name, self.nodes_list)
-        node_idx_trace = self.idx_trace_list[node_idx]['idx']
-        return node_idx_trace
+        return self.idx_trace_list[node_idx]['idx']
+
+    def find_compute_trace_from_node(self, node):
+        node_idx = _find_idx_by_name(node.name, self.nodes_list)
+        return self.idx_trace_list[node_idx]['compute']

     def assign_index_as_input(self, node, node_idx):
         input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
@@ -82,6 +87,18 @@ class NodeIndexTracer(object):
         new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
+        self.inherit_computation(node.args[0], node)
+
+    def assign_permute_index(self, node, node_idx):
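+        # output dim i takes the index the input carried at permute_dim[i],
+        # e.g. permute(0, 2, 1) turns a trace [i0, i1, i2] into [i0, i2, i1]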
+        permute_dim = node.args[1:]
+        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
+        new_idx_trace = copy.deepcopy(input_node_idx_trace)
+        for idx, d in enumerate(permute_dim):
+            new_idx_trace[idx] = input_node_idx_trace[d]
+        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
+        self.inherit_computation(node.args[0], node)

     def assign_linear_index(self, node, node_idx):
         input_node, weight, bias = node.args
@@ -100,10 +117,99 @@ class NodeIndexTracer(object):
         bias_idx_trace = self.find_idx_trace_from_node(bias)
         self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])

+    def assign_matmul_index(self, node, node_idx):
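+        # for A @ B with A (..., m, k) and B (..., k, n): the output keeps A's
+        # indices except the last, which it takes from B; the contracted index
+        # k is recorded as computed and the two k indices are marked equal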
+        matmul_left, matmul_right = node.args
+        matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
+        matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
+        assert len(matmul_left_idx_trace) == len(matmul_right_idx_trace)
+        new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
+        new_idx_trace[-1] = matmul_right_idx_trace[-1]
+        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
+        self.inherit_computation(matmul_left, node)
+        self.inherit_computation(matmul_right, node)
+        self.mark_computation(node, node_idx, [-1])
+        self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])

     def assign_layernorm_index(self, node, idx):
         self.assign_index_as_input(node, idx)
+        self.inherit_computation(node.args[0], node)
         self.mark_computation(node, idx, [-1, -2])
+
+    def assign_elementwise_index(self, node, idx):
+        self.assign_index_as_input(node, idx)
+        for node_in in node.args:
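+            # python scalars carry no index trace; only fx-node args inherit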
+            if type(node_in) not in (int, float):
+                self.inherit_computation(node_in, node)
+
+    def assign_softmax_index(self, node, idx):
+        self.assign_index_as_input(node, idx)
+        self.mark_computation(node, idx, [node.kwargs['dim']])
+
+    def assign_view_reshape_index(self, node, node_idx):
+        # get data, turn into number
+        origin_node = node.args[0]
+        origin_shape = origin_node.meta['tensor_meta'].shape
+        target_shape = []
+        for i in range(1, len(node.args)):
+            if isinstance(node.args[i], int):
+                target_shape.append(node.args[i])
+            else:
+                target_shape.append(node.args[i].meta['fwd_out'][0])
+
+        # compute the value of -1
+        if -1 in target_shape:
+            origin_product = 1
+            for i in origin_shape:
+                origin_product *= i
+            target_product = -1
+            for i in target_shape:
+                target_product *= i
+            shape_idx = target_shape.index(-1)
+            target_shape[shape_idx] = origin_product // target_product
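+            # e.g. viewing (2, 3, 4) as (-1, 4): origin_product = 24 and
+            # target_product = 4 (the initial -1 cancels the placeholder -1),
+            # so the -1 slot resolves to 24 // 4 = 6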
+
+        # determine changed dim
+        len_diff = len(origin_shape) - len(target_shape)
+        if len_diff == 1:
+            # dim merge
+            dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
+            dim_to = [dim_equal.index(False)]
+            dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
+        elif len_diff == -1:
+            # dim expand
+            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
+            dim_from = [dim_equal.index(False)]
+            dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
+        else:
+            raise NotImplementedError("view from shape " + str(origin_shape) + " to " + str(target_shape) + " is not implemented")
+
+        # get new index
+        origin_trace = self.find_idx_trace_from_node(origin_node)
+        new_trace = copy.deepcopy(origin_trace)
+        dim_from.reverse()
+        for i in dim_from:
+            new_trace.pop(i)
+        for i in dim_to:
+            new_trace.insert(i, self.add_index())
+        self.idx_trace_list[node_idx]['idx'] = new_trace
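+        # e.g. expanding (2, 4, 6) -> (2, 4, 2, 3) gives dim_from = [2] and
+        # dim_to = [2, 3]: the trace [i0, i1, i2] drops i2 and gains two fresh
+        # indices, becoming [i0, i1, i3, i4]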
+
+        # inherit computation
+        self.inherit_computation(origin_node, node)
+        compute_log = self.find_compute_trace_from_node(origin_node)
+        for i in dim_from:
+            if origin_trace[i] in compute_log:
+                for j in dim_to:
+                    self.mark_computation(node, node_idx, [j])
+                break
+
+        # log view
+        view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
+                     "dim_from": dim_from,
+                     "idx_to": [new_trace[i] for i in dim_to],
+                     "dim_to": dim_to}
+        self.idx_view_list.append(view_dict)

     def trace_node_idx(self):
         for idx, node in enumerate(self.nodes_list):
             if node.op == 'placeholder':
@@ -111,15 +217,21 @@ class NodeIndexTracer(object):
             elif node.op == 'call_method':
                 if 'transpose' in node.name:
                     self.assign_transpose_index(node, idx)
-                elif 'view' in node.name:
-                    pass
                 elif 'permute' in node.name:
-                    pass
+                    self.assign_permute_index(node, idx)
+                elif 'view' in node.name or 'reshape' in node.name:
+                    self.assign_view_reshape_index(node, idx)
                 else:
                     raise NotImplementedError(node.name, "method not implemented yet!")
             elif node.op == 'call_function':
                 if 'linear' in node.name:
                     self.assign_linear_index(node, idx)
+                elif 'matmul' in node.name:
+                    self.assign_matmul_index(node, idx)
+                elif 'softmax' in node.name:
+                    self.assign_softmax_index(node, idx)
+                elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
+                    self.assign_elementwise_index(node, idx)
                 elif 'getattr' in node.name:
                     continue    # get attr like shape
                 elif 'getitem' in node.name:
@@ -127,12 +239,14 @@ class NodeIndexTracer(object):
                 else:
                     raise NotImplementedError(node.name, "function not implemented yet!")
             elif node.op == 'call_module':
-                if 'layernorm' in node.name:
+                if any(n in node.name for n in ['layernorm', 'norm']):
                     self.assign_layernorm_index(node, idx)
                 else:
                     raise NotImplementedError(node.name, "module not implemented yet!")
             elif node.op == 'get_attr':
                 self.assign_all_index(node, idx)    # get param
+            elif node.op == 'output':
+                continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
@@ -297,6 +411,7 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
     # node is an operation, calculate tmp, output node and delete node memory
     else:
         # forward memory
+        # TODO: permute will create a tmp copy if not contiguous
         act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
         act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
         # record max act memory