mirror of https://github.com/hpcaitech/ColossalAI
finish basic index tracer
parent
1607d04e81
commit
c36dba07de
133
chunk_codegen.py
133
chunk_codegen.py
|
@ -25,6 +25,7 @@ class NodeIndexTracer(object):
|
|||
self.nodes_list = list(gm.graph.nodes)
|
||||
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
|
||||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
self.idx_count = 1
|
||||
|
||||
def add_index(self):
|
||||
|
@ -35,7 +36,7 @@ class NodeIndexTracer(object):
|
|||
_, compute_from = self.find_trace_from_node(node_from)
|
||||
idx_to, compute_to = self.find_trace_from_node(node_to)
|
||||
for i in compute_from:
|
||||
if i in idx_to:
|
||||
if i in idx_to and i not in compute_to:
|
||||
compute_to.append(i)
|
||||
|
||||
def mark_idx_equal(self, idx1, idx2):
|
||||
|
@ -47,7 +48,8 @@ class NodeIndexTracer(object):
|
|||
dim = [dim]
|
||||
for d in dim:
|
||||
cur_idx = input_node_idx_trace[d]
|
||||
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
||||
if cur_idx not in self.idx_trace_list[idx]['compute']:
|
||||
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
||||
|
||||
def find_trace_from_node(self, node):
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
|
@ -56,8 +58,11 @@ class NodeIndexTracer(object):
|
|||
|
||||
def find_idx_trace_from_node(self, node):
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_idx_trace = self.idx_trace_list[node_idx]['idx']
|
||||
return node_idx_trace
|
||||
return self.idx_trace_list[node_idx]['idx']
|
||||
|
||||
def find_compute_trace_from_node(self, node):
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
return self.idx_trace_list[node_idx]['compute']
|
||||
|
||||
def assign_index_as_input(self, node, node_idx):
|
||||
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
|
||||
|
@ -82,6 +87,18 @@ class NodeIndexTracer(object):
|
|||
new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
|
||||
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self.inherit_computation(node.args[0], node)
|
||||
|
||||
def assign_permute_index(self, node, node_idx):
|
||||
permute_dim = node.args[1:]
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
for idx, d in enumerate(permute_dim):
|
||||
new_idx_trace[idx] = input_node_idx_trace[d]
|
||||
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self.inherit_computation(node.args[0], node)
|
||||
|
||||
def assign_linear_index(self, node, node_idx):
|
||||
input_node, weight, bias = node.args
|
||||
|
@ -100,10 +117,99 @@ class NodeIndexTracer(object):
|
|||
bias_idx_trace = self.find_idx_trace_from_node(bias)
|
||||
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
||||
|
||||
def assign_matmul_index(self, node, node_idx):
|
||||
matmul_left, matmul_right = node.args
|
||||
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
|
||||
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
|
||||
|
||||
assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace))
|
||||
new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
|
||||
new_idx_trace[-1] = matmul_right_idx_trace[-1]
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
|
||||
self.inherit_computation(matmul_left, node)
|
||||
self.inherit_computation(matmul_right, node)
|
||||
self.mark_computation(node, node_idx, [-1])
|
||||
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
||||
|
||||
def assign_layernorm_index(self, node, idx):
|
||||
self.assign_index_as_input(node, idx)
|
||||
self.inherit_computation(node.args[0], node)
|
||||
self.mark_computation(node, idx, [-1, -2])
|
||||
|
||||
|
||||
def assign_elementwise_index(self, node, idx):
|
||||
self.assign_index_as_input(node, idx)
|
||||
for node_in in node.args:
|
||||
if type(node_in) not in (int, float):
|
||||
self.inherit_computation(node_in, node)
|
||||
|
||||
def assign_softmax_index(self, node, idx):
|
||||
self.assign_index_as_input(node, idx)
|
||||
self.mark_computation(node, idx, [node.kwargs['dim']])
|
||||
|
||||
def assign_view_reshape_index(self, node, node_idx):
|
||||
# get data, turn into number
|
||||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta['tensor_meta'].shape
|
||||
target_shape = []
|
||||
for i in range(1, len(node.args)):
|
||||
if isinstance(node.args[i], int):
|
||||
target_shape.append(node.args[i])
|
||||
else:
|
||||
target_shape.append(node.args[i].meta['fwd_out'][0])
|
||||
|
||||
# compute the value of -1
|
||||
if -1 in target_shape:
|
||||
origin_product = 1
|
||||
for i in origin_shape:
|
||||
origin_product *= i
|
||||
target_product = -1
|
||||
for i in target_shape:
|
||||
target_product *= i
|
||||
shape_idx = target_shape.index(-1)
|
||||
target_shape[shape_idx] = origin_product // target_product
|
||||
|
||||
# determine changed dim
|
||||
len_diff = len(origin_shape) - len(target_shape)
|
||||
if len_diff == 1:
|
||||
# dim merge
|
||||
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
|
||||
dim_to = [dim_equal.index(False)]
|
||||
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
elif len_diff == -1:
|
||||
# dim expand
|
||||
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||
dim_from = [dim_equal.index(False)]
|
||||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
else:
|
||||
raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")
|
||||
|
||||
# get new index
|
||||
origin_trace = self.find_idx_trace_from_node(origin_node)
|
||||
new_trace = copy.deepcopy(origin_trace)
|
||||
dim_from.reverse()
|
||||
for i in dim_from:
|
||||
new_trace.pop(i)
|
||||
for i in dim_to:
|
||||
new_trace.insert(i, self.add_index())
|
||||
self.idx_trace_list[node_idx]['idx'] = new_trace
|
||||
|
||||
# inherit computation
|
||||
self.inherit_computation(origin_node, node)
|
||||
compute_log = self.find_compute_trace_from_node(origin_node)
|
||||
for i in dim_from:
|
||||
if origin_trace[i] in compute_log:
|
||||
for j in dim_to:
|
||||
self.mark_computation(node, node_idx, [j])
|
||||
break
|
||||
|
||||
# log view
|
||||
view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
|
||||
"dim_from": dim_from,
|
||||
"idx_to": [new_trace[i] for i in dim_to],
|
||||
"dim_to": dim_to}
|
||||
self.idx_view_list.append(view_dict)
|
||||
|
||||
def trace_node_idx(self):
|
||||
for idx, node in enumerate(self.nodes_list):
|
||||
if node.op == 'placeholder':
|
||||
|
@ -111,15 +217,21 @@ class NodeIndexTracer(object):
|
|||
elif node.op == 'call_method':
|
||||
if 'transpose' in node.name:
|
||||
self.assign_transpose_index(node, idx)
|
||||
elif 'view' in node.name:
|
||||
pass
|
||||
elif 'permute' in node.name:
|
||||
pass
|
||||
self.assign_permute_index(node, idx)
|
||||
elif 'view' in node.name or 'reshape' in node.name:
|
||||
self.assign_view_reshape_index(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == 'call_function':
|
||||
if 'linear' in node.name:
|
||||
self.assign_linear_index(node, idx)
|
||||
elif 'matmul' in node.name:
|
||||
self.assign_matmul_index(node, idx)
|
||||
elif 'softmax' in node.name:
|
||||
self.assign_softmax_index(node, idx)
|
||||
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
|
||||
self.assign_elementwise_index(node, idx)
|
||||
elif 'getattr' in node.name:
|
||||
continue # get attr like shape
|
||||
elif 'getitem' in node.name:
|
||||
|
@ -127,12 +239,14 @@ class NodeIndexTracer(object):
|
|||
else:
|
||||
raise NotImplementedError(node.name, "function not implemented yet!")
|
||||
elif node.op == 'call_module':
|
||||
if 'layernorm' in node.name:
|
||||
if any(n in node.name for n in ['layernorm', 'norm']):
|
||||
self.assign_layernorm_index(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||
elif node.op == 'get_attr':
|
||||
self.assign_all_index(node, idx) # get param
|
||||
elif node.op == 'output':
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
|
@ -297,6 +411,7 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
|
|||
# node is an operation, calculate tmp, output node and delete node memory
|
||||
else:
|
||||
# forward memory
|
||||
# TODO: permute will create a tmp copy if not contiguous
|
||||
act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
|
||||
act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
# record max act memory
|
||||
|
|
Loading…
Reference in New Issue