mirror of https://github.com/hpcaitech/ColossalAI
add part of index tracer
parent
d7634af5c0
commit
1607d04e81
119
chunk_codegen.py
119
chunk_codegen.py
|
@ -19,6 +19,123 @@ else:
|
|||
__all__ = ['python_code_with_activation_checkpoint']
|
||||
|
||||
|
||||
class NodeIndexTracer(object):
    """Assign an integer index to every tensor dimension of each graph node.

    For every node in ``gm.graph.nodes`` the tracer keeps a record
    ``{'idx': [...], 'compute': [...]}``:

    * ``'idx'`` holds one integer index per dimension of the node's output.
    * ``'compute'`` holds the indices that were marked as taking part in a
      computation (see :meth:`mark_computation`).

    Pairs of indices that were declared equal are collected in
    ``idx_trace_equal``.
    """

    def __init__(self, gm) -> None:
        """Build empty trace records for every node of ``gm.graph``.

        Args:
            gm: object exposing ``graph.nodes`` (an iterable of nodes).
        """
        self.gm = gm
        self.nodes_list = list(gm.graph.nodes)
        # One record per node, positionally aligned with ``nodes_list``.
        self.idx_trace_list = [{'idx': [], 'compute': []} for _ in self.nodes_list]
        # Pairs ``(idx1, idx2)`` registered through ``mark_idx_equal``.
        self.idx_trace_equal = []
        # Next index to hand out; indices therefore start at 1.
        self.idx_count = 1

    def add_index(self):
        """Allocate and return a fresh, previously unused index."""
        new_idx = self.idx_count
        self.idx_count += 1
        return new_idx

    def inherit_computation(self, node_from, node_to):
        """Copy computed indices of ``node_from`` that ``node_to`` still owns.

        Only indices present in ``node_to``'s ``'idx'`` list are copied, and
        an index already recorded in ``node_to``'s ``'compute'`` list is not
        appended a second time.
        """
        _, compute_from = self.find_trace_from_node(node_from)
        idx_to, compute_to = self.find_trace_from_node(node_to)
        for i in compute_from:
            if i in idx_to and i not in compute_to:
                compute_to.append(i)

    def mark_idx_equal(self, idx1, idx2):
        """Record that ``idx1`` and ``idx2`` are considered equal."""
        self.idx_trace_equal.append((idx1, idx2))

    def mark_computation(self, node, idx, dim):
        """Mark the indices at ``dim`` of ``node`` as computed for record ``idx``.

        Args:
            node: node whose ``'idx'`` list is read.
            idx: position in ``idx_trace_list`` whose ``'compute'`` list is
                extended.
            dim: a single dimension or a list of dimensions (negative values
                index from the end, as with ordinary list indexing).
        """
        node_idx_trace = self.find_idx_trace_from_node(node)
        if isinstance(dim, int):
            dim = [dim]
        computed = self.idx_trace_list[idx]['compute']
        for d in dim:
            cur_idx = node_idx_trace[d]
            # Keep the list free of duplicates when a dim is marked twice.
            if cur_idx not in computed:
                computed.append(cur_idx)

    def find_trace_from_node(self, node):
        """Return the ``('idx', 'compute')`` lists stored for ``node``.

        The node is located by name via ``_find_idx_by_name``; the returned
        lists are the stored objects themselves, not copies.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        node_dict = self.idx_trace_list[node_idx]
        return node_dict['idx'], node_dict['compute']

    def find_idx_trace_from_node(self, node):
        """Return the stored ``'idx'`` list for ``node`` (not a copy)."""
        idx_trace, _ = self.find_trace_from_node(node)
        return idx_trace

    def assign_index_as_input(self, node, node_idx):
        """Give record ``node_idx`` a copy of the indices of ``node.args[0]``."""
        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
        # The trace is a flat list of ints, so a shallow copy is sufficient.
        self.idx_trace_list[node_idx]['idx'] = list(input_node_idx_trace)

    def assign_all_index(self, node, node_idx):
        """Allocate a fresh index for every dimension of ``node``.

        The number of dimensions is taken from
        ``node.meta['tensor_meta'].shape``.
        """
        shape = node.meta['tensor_meta'].shape
        self.idx_trace_list[node_idx]['idx'] = [self.add_index() for _ in shape]

    def assign_transpose_index(self, node, node_idx):
        """Copy the input's indices with the two transposed dims swapped.

        The two dimensions are read from ``node.args[1]`` and
        ``node.args[2]``.
        """
        dim0, dim1 = node.args[1], node.args[2]
        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])

        new_idx_trace = list(input_node_idx_trace)
        new_idx_trace[dim0] = input_node_idx_trace[dim1]
        new_idx_trace[dim1] = input_node_idx_trace[dim0]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

    def assign_linear_index(self, node, node_idx):
        """Assign indices for a ``linear`` call and mark its last dim computed.

        The output inherits the input's indices, except that the last one is
        replaced by the weight's index at position 1. The bias is optional:
        it is read from the third positional argument if present, otherwise
        from ``node.kwargs['bias']``.
        """
        input_node, weight = node.args[0], node.args[1]
        if len(node.args) > 2:
            bias = node.args[2]
        else:
            bias = node.kwargs.get('bias')

        input_node_idx_trace = self.find_idx_trace_from_node(input_node)
        weight_idx_trace = self.find_idx_trace_from_node(weight)

        # NOTE(review): torch.nn.functional.linear documents weight as
        # (out_features, in_features); the pairing of weight dims 0/1 here
        # (and of the bias dim below) looks swapped -- confirm intent.
        new_idx_trace = list(input_node_idx_trace)
        new_idx_trace[-1] = weight_idx_trace[1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self.inherit_computation(input_node, node)
        self.mark_computation(node, node_idx, [-1])
        self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])

        if bias is not None:
            bias_idx_trace = self.find_idx_trace_from_node(bias)
            self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])

    def assign_layernorm_index(self, node, idx):
        """Reuse the input's indices and mark the last two dims as computed."""
        self.assign_index_as_input(node, idx)
        self.mark_computation(node, idx, [-1, -2])

    def trace_node_idx(self):
        """Walk ``nodes_list`` in order and fill in each node's trace record.

        Raises:
            NotImplementedError: for a node whose op, method, function or
                module has no handler here.
        """
        for idx, node in enumerate(self.nodes_list):
            if node.op == 'placeholder':
                self.assign_all_index(node, idx)
            elif node.op == 'call_method':
                if 'transpose' in node.name:
                    self.assign_transpose_index(node, idx)
                elif 'view' in node.name:
                    pass    # not handled yet: record stays empty
                elif 'permute' in node.name:
                    pass    # not handled yet: record stays empty
                else:
                    raise NotImplementedError(node.name, "method not implemented yet!")
            elif node.op == 'call_function':
                if 'linear' in node.name:
                    self.assign_linear_index(node, idx)
                elif 'getattr' in node.name:
                    continue    # get attr like shape
                elif 'getitem' in node.name:
                    continue    # get item in list
                else:
                    raise NotImplementedError(node.name, "function not implemented yet!")
            elif node.op == 'call_module':
                if 'layernorm' in node.name:
                    self.assign_layernorm_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "module not implemented yet!")
            elif node.op == 'get_attr':
                self.assign_all_index(node, idx)    # get param
            elif node.op == 'output':
                # The graph's terminating node needs no trace of its own;
                # without this branch a complete graph could never be traced.
                continue
            else:
                raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
def _get_meta_node_size(x):
|
||||
x = x.meta['tensor_meta']
|
||||
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
||||
|
@ -557,6 +674,8 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
node_list = list(nodes)
|
||||
_estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
|
||||
_estimate_inference_mem(meta_graph)
|
||||
node_index_tracer = NodeIndexTracer(meta_graph)
|
||||
node_index_tracer.trace_node_idx()
|
||||
|
||||
# find the input and output var names for each offload region
|
||||
for idx, (start, end) in enumerate(chunk_regions):
|
||||
|
|
Loading…
Reference in New Issue