mirror of https://github.com/hpcaitech/ColossalAI
polish code
parent
7e2bd1e428
commit
fad3b6d1a6
216
chunk_codegen.py
216
chunk_codegen.py
|
@ -10,6 +10,13 @@ CODEGEN_AVAILABLE = True
|
|||
__all__ = ['ChunkCodeGen']
|
||||
|
||||
|
||||
def _delete_free_var_from_last_use(user_to_last_uses):
|
||||
for key, value in user_to_last_uses.items():
|
||||
for n in value:
|
||||
if n.op == 'placeholder':
|
||||
user_to_last_uses[key].remove(n)
|
||||
|
||||
|
||||
class NodeIndexTracer(object):
|
||||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
|
@ -19,7 +26,7 @@ class NodeIndexTracer(object):
|
|||
self.idx_view_list = []
|
||||
self.idx_count = -1
|
||||
|
||||
def add_index(self):
|
||||
def _add_index(self):
|
||||
"""
|
||||
Update the count and return it. To record the idx number.
|
||||
|
||||
|
@ -29,7 +36,7 @@ class NodeIndexTracer(object):
|
|||
self.idx_count += 1
|
||||
return self.idx_count
|
||||
|
||||
def inherit_computation(self, node_from, node_to):
|
||||
def _inherit_computation(self, node_from, node_to):
|
||||
"""
|
||||
Inherit computed dim from node_from to node_to.
|
||||
If a dim in node_from is marked as computed and exists in node_to,
|
||||
|
@ -39,13 +46,13 @@ class NodeIndexTracer(object):
|
|||
node_from (node): node to be inherited
|
||||
node_to (node): new node to inherit
|
||||
"""
|
||||
_, compute_from = self.find_trace_from_node(node_from)
|
||||
idx_to, compute_to = self.find_trace_from_node(node_to)
|
||||
_, compute_from = self._find_trace_from_node(node_from)
|
||||
idx_to, compute_to = self._find_trace_from_node(node_to)
|
||||
for i in compute_from:
|
||||
if i in idx_to and i not in compute_to:
|
||||
compute_to.append(i)
|
||||
|
||||
def mark_idx_equal(self, idx1, idx2):
|
||||
def _mark_idx_equal(self, idx1, idx2):
|
||||
"""
|
||||
Mark 2 index to be equal.
|
||||
|
||||
|
@ -55,7 +62,7 @@ class NodeIndexTracer(object):
|
|||
"""
|
||||
self.idx_trace_equal.append((idx1, idx2))
|
||||
|
||||
def mark_computation(self, node, idx, dim):
|
||||
def _mark_computation(self, node, idx, dim):
|
||||
"""
|
||||
Mark some dims of node as computed.
|
||||
|
||||
|
@ -64,7 +71,7 @@ class NodeIndexTracer(object):
|
|||
idx (int): node index
|
||||
dim (list or int): dims to be marked as computed
|
||||
"""
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node)
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(node)
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
for d in dim:
|
||||
|
@ -72,7 +79,7 @@ class NodeIndexTracer(object):
|
|||
if cur_idx not in self.idx_trace_list[idx]['compute']:
|
||||
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
||||
|
||||
def find_trace_from_node(self, node):
|
||||
def _find_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx and compute trace by the node.
|
||||
|
||||
|
@ -86,7 +93,7 @@ class NodeIndexTracer(object):
|
|||
node_dict = self.idx_trace_list[node_idx]
|
||||
return node_dict['idx'], node_dict['compute']
|
||||
|
||||
def find_idx_trace_from_node(self, node):
|
||||
def _find_idx_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx trace by the node.
|
||||
|
||||
|
@ -98,7 +105,7 @@ class NodeIndexTracer(object):
|
|||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
return self.idx_trace_list[node_idx]['idx']
|
||||
|
||||
def find_compute_trace_from_node(self, node):
|
||||
def _find_compute_trace_from_node(self, node):
|
||||
"""
|
||||
Find node compute trace by the node.
|
||||
|
||||
|
@ -110,7 +117,7 @@ class NodeIndexTracer(object):
|
|||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
return self.idx_trace_list[node_idx]['compute']
|
||||
|
||||
def assign_index_as_input(self, node, node_idx):
|
||||
def _assign_index_as_input(self, node, node_idx):
|
||||
"""
|
||||
Assign node's trace as its input node.
|
||||
|
||||
|
@ -124,7 +131,7 @@ class NodeIndexTracer(object):
|
|||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
|
||||
def assign_all_index(self, node, node_idx):
|
||||
def _assign_all_index(self, node, node_idx):
|
||||
"""
|
||||
Add new index for all node's dims.
|
||||
|
||||
|
@ -135,10 +142,10 @@ class NodeIndexTracer(object):
|
|||
shape = node.meta['tensor_meta'].shape
|
||||
new_trace = []
|
||||
for _ in shape:
|
||||
new_trace.append(self.add_index())
|
||||
new_trace.append(self._add_index())
|
||||
self.idx_trace_list[node_idx]['idx'] = new_trace
|
||||
|
||||
def assign_transpose_index(self, node, node_idx):
|
||||
def _assign_transpose_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for transpose op.
|
||||
1. swap input's dim according to transpose args
|
||||
|
@ -149,16 +156,16 @@ class NodeIndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
tranpose_dim = node.args[1:]
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]]
|
||||
new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
|
||||
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self.inherit_computation(node.args[0], node)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
|
||||
def assign_permute_index(self, node, node_idx):
|
||||
def _assign_permute_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for permute op.
|
||||
1. swap input's dim according to permute args
|
||||
|
@ -169,16 +176,16 @@ class NodeIndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
permute_dim = node.args[1:]
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
for idx, d in enumerate(permute_dim):
|
||||
new_idx_trace[idx] = input_node_idx_trace[d]
|
||||
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self.inherit_computation(node.args[0], node)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
|
||||
def assign_linear_index(self, node, node_idx):
|
||||
def _assign_linear_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for linear op.
|
||||
1. copy trace from input node and change last index accroding to weight
|
||||
|
@ -190,22 +197,22 @@ class NodeIndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
input_node, weight, bias = node.args
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
|
||||
weight_idx_trace = self.find_idx_trace_from_node(weight)
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
|
||||
weight_idx_trace = self._find_idx_trace_from_node(weight)
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
new_idx_trace[-1] = weight_idx_trace[1]
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
|
||||
self.inherit_computation(input_node, node)
|
||||
self.mark_computation(node, node_idx, [-1])
|
||||
self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
|
||||
self._inherit_computation(input_node, node)
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
|
||||
|
||||
if bias:
|
||||
bias_idx_trace = self.find_idx_trace_from_node(bias)
|
||||
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
||||
bias_idx_trace = self._find_idx_trace_from_node(bias)
|
||||
self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
||||
|
||||
def assign_matmul_index(self, node, node_idx):
|
||||
def _assign_matmul_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for matmul op.
|
||||
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
|
||||
|
@ -217,20 +224,20 @@ class NodeIndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
matmul_left, matmul_right = node.args
|
||||
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
|
||||
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
|
||||
matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left)
|
||||
matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right)
|
||||
|
||||
assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace))
|
||||
new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
|
||||
new_idx_trace[-1] = matmul_right_idx_trace[-1]
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
|
||||
self.inherit_computation(matmul_left, node)
|
||||
self.inherit_computation(matmul_right, node)
|
||||
self.mark_computation(node, node_idx, [-1])
|
||||
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
||||
self._inherit_computation(matmul_left, node)
|
||||
self._inherit_computation(matmul_right, node)
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
||||
|
||||
def assign_layernorm_index(self, node, idx):
|
||||
def _assign_layernorm_index(self, node, idx):
|
||||
"""
|
||||
Assign index for layernorm op.
|
||||
1. assign index as input node
|
||||
|
@ -240,11 +247,11 @@ class NodeIndexTracer(object):
|
|||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self.assign_index_as_input(node, idx)
|
||||
self.inherit_computation(node.args[0], node)
|
||||
self.mark_computation(node, idx, [-1, -2])
|
||||
self._assign_index_as_input(node, idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._mark_computation(node, idx, [-1, -2])
|
||||
|
||||
def assign_elementwise_index(self, node, idx):
|
||||
def _assign_elementwise_index(self, node, idx):
|
||||
"""
|
||||
Assign index for element-wise op (eg. relu sigmoid add mul).
|
||||
1. assign index as input node
|
||||
|
@ -254,12 +261,12 @@ class NodeIndexTracer(object):
|
|||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self.assign_index_as_input(node, idx)
|
||||
self._assign_index_as_input(node, idx)
|
||||
for node_in in node.args:
|
||||
if type(node_in) not in (int, float):
|
||||
self.inherit_computation(node_in, node)
|
||||
self._inherit_computation(node_in, node)
|
||||
|
||||
def assign_softmax_index(self, node, idx):
|
||||
def _assign_softmax_index(self, node, idx):
|
||||
"""
|
||||
Assign index for softmax op.
|
||||
1. assign index as input node
|
||||
|
@ -269,11 +276,11 @@ class NodeIndexTracer(object):
|
|||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self.assign_index_as_input(node, idx)
|
||||
self.inherit_computation(node.args[0], node)
|
||||
self.mark_computation(node, idx, [node.kwargs['dim']])
|
||||
self._assign_index_as_input(node, idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._mark_computation(node, idx, [node.kwargs['dim']])
|
||||
|
||||
def assign_view_reshape_index(self, node, node_idx):
|
||||
def _assign_view_reshape_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for view and reshape op.
|
||||
1. get origin shape and target shape by meta info.
|
||||
|
@ -325,22 +332,22 @@ class NodeIndexTracer(object):
|
|||
raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")
|
||||
|
||||
# get new index
|
||||
origin_trace = self.find_idx_trace_from_node(origin_node)
|
||||
origin_trace = self._find_idx_trace_from_node(origin_node)
|
||||
new_trace = copy.deepcopy(origin_trace)
|
||||
dim_from.reverse()
|
||||
for i in dim_from:
|
||||
new_trace.pop(i)
|
||||
for i in dim_to:
|
||||
new_trace.insert(i, self.add_index())
|
||||
new_trace.insert(i, self._add_index())
|
||||
self.idx_trace_list[node_idx]['idx'] = new_trace
|
||||
|
||||
# inherit computation
|
||||
self.inherit_computation(origin_node, node)
|
||||
compute_log = self.find_compute_trace_from_node(origin_node)
|
||||
self._inherit_computation(origin_node, node)
|
||||
compute_log = self._find_compute_trace_from_node(origin_node)
|
||||
for i in dim_from:
|
||||
if origin_trace[i] in compute_log:
|
||||
for j in dim_to:
|
||||
self.mark_computation(node, node_idx, [j])
|
||||
self._mark_computation(node, node_idx, [j])
|
||||
break
|
||||
|
||||
# log view, not used now
|
||||
|
@ -353,25 +360,25 @@ class NodeIndexTracer(object):
|
|||
def trace_node_idx(self):
|
||||
for idx, node in enumerate(self.nodes_list):
|
||||
if node.op == 'placeholder':
|
||||
self.assign_all_index(node, idx)
|
||||
self._assign_all_index(node, idx)
|
||||
elif node.op == 'call_method':
|
||||
if 'transpose' in node.name:
|
||||
self.assign_transpose_index(node, idx)
|
||||
self._assign_transpose_index(node, idx)
|
||||
elif 'permute' in node.name:
|
||||
self.assign_permute_index(node, idx)
|
||||
self._assign_permute_index(node, idx)
|
||||
elif 'view' in node.name or 'reshape' in node.name:
|
||||
self.assign_view_reshape_index(node, idx)
|
||||
self._assign_view_reshape_index(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == 'call_function':
|
||||
if 'linear' in node.name:
|
||||
self.assign_linear_index(node, idx)
|
||||
self._assign_linear_index(node, idx)
|
||||
elif 'matmul' in node.name:
|
||||
self.assign_matmul_index(node, idx)
|
||||
self._assign_matmul_index(node, idx)
|
||||
elif 'softmax' in node.name:
|
||||
self.assign_softmax_index(node, idx)
|
||||
self._assign_softmax_index(node, idx)
|
||||
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
|
||||
self.assign_elementwise_index(node, idx)
|
||||
self._assign_elementwise_index(node, idx)
|
||||
elif 'getattr' in node.name:
|
||||
continue # get attr like shape
|
||||
elif 'getitem' in node.name:
|
||||
|
@ -380,39 +387,40 @@ class NodeIndexTracer(object):
|
|||
raise NotImplementedError(node.name, "function not implemented yet!")
|
||||
elif node.op == 'call_module':
|
||||
if any(n in node.name for n in ['layernorm', 'norm']):
|
||||
self.assign_layernorm_index(node, idx)
|
||||
self._assign_layernorm_index(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||
elif node.op == 'get_attr':
|
||||
self.assign_all_index(node, idx) # get param
|
||||
self._assign_all_index(node, idx) # get param
|
||||
elif node.op == 'output':
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
|
||||
def _get_meta_node_size(x):
|
||||
class MemoryEstimator(object):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_meta_node_size(self, x):
|
||||
x = x.meta['tensor_meta']
|
||||
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
||||
return x
|
||||
|
||||
|
||||
def _get_output_node_size(n):
|
||||
def _get_output_node_size(self, n):
|
||||
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
|
||||
return activation_size(fwd_out)
|
||||
|
||||
|
||||
def _get_delete_node_size(user, user_to_last_uses):
|
||||
def _get_delete_node_size(self, user, user_to_last_uses):
|
||||
if user.op in ('placeholder', 'output'):
|
||||
return 0
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if len(nodes_to_delete):
|
||||
delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
|
||||
delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete])
|
||||
return delete_size
|
||||
return 0
|
||||
|
||||
|
||||
def _get_last_usr(nodes):
|
||||
def _get_last_usr(self, nodes):
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
|
@ -426,15 +434,7 @@ def _get_last_usr(nodes):
|
|||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
return user_to_last_uses
|
||||
|
||||
|
||||
def _delete_free_var_from_last_use(user_to_last_uses):
|
||||
for key, value in user_to_last_uses.items():
|
||||
for n in value:
|
||||
if n.op == 'placeholder':
|
||||
user_to_last_uses[key].remove(n)
|
||||
|
||||
|
||||
def _get_contiguous_memory(node, not_contiguous_list, delete=False):
|
||||
def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
|
||||
mem = 0
|
||||
not_contiguous_ops = ['transpose', 'permute']
|
||||
|
||||
|
@ -442,7 +442,7 @@ def _get_contiguous_memory(node, not_contiguous_list, delete=False):
|
|||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# matmul won't change origin tensor, but create a tmp copy
|
||||
mem += _get_output_node_size(n)
|
||||
mem += self._get_output_node_size(n)
|
||||
elif node.op == 'call_module':
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
|
@ -458,18 +458,17 @@ def _get_contiguous_memory(node, not_contiguous_list, delete=False):
|
|||
|
||||
return mem
|
||||
|
||||
|
||||
def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
||||
def estimate_inference_mem(self, gm: torch.fx.GraphModule):
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
act_memory_after_node_log = []
|
||||
not_contiguous_list = []
|
||||
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
|
||||
user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
|
||||
_delete_free_var_from_last_use(user_to_last_uses)
|
||||
for node in gm.graph.nodes:
|
||||
# if node is placeholder, just add the size of the node
|
||||
if node.op == 'placeholder':
|
||||
act_memory += _get_meta_node_size(node) / (1024 ** 2)
|
||||
act_memory += self._get_meta_node_size(node) / (1024 ** 2)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
# skip output
|
||||
|
@ -478,30 +477,30 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
|||
# node is an operation, calculate tmp, output node and delete node memory
|
||||
else:
|
||||
# forward memory
|
||||
act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
|
||||
act_memory += _get_output_node_size(node) / (1024 ** 2)
|
||||
act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
|
||||
act_memory += self._get_output_node_size(node) / (1024 ** 2)
|
||||
# record max act memory
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# delete useless memory
|
||||
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
|
||||
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
|
||||
act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
|
||||
act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
|
||||
print("no chunk")
|
||||
_print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
|
||||
_print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
|
||||
self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
|
||||
self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
|
||||
|
||||
param_memory = parameter_size(gm)
|
||||
return act_memory + param_memory, param_memory
|
||||
|
||||
|
||||
def _get_chunk_ratio(node, chunk_dim, chunk_size):
|
||||
def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
|
||||
shape = node.meta['tensor_meta'].shape
|
||||
chunk_ratio = float(chunk_size) / shape[chunk_dim]
|
||||
return chunk_ratio
|
||||
|
||||
|
||||
def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
|
||||
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
|
||||
if user.op in ('placeholder', 'output'):
|
||||
return 0
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
|
@ -509,11 +508,11 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
|
|||
for n in nodes_to_delete:
|
||||
node_idx = _find_idx_by_name(n.name, node_list)
|
||||
if start_node <= node_idx < end_node:
|
||||
delete_size += _get_output_node_size(n) * chunk_ratio
|
||||
delete_size += self._get_output_node_size(n) * chunk_ratio
|
||||
return delete_size
|
||||
|
||||
|
||||
def _print_mem_log(log, nodes, title=None):
|
||||
def _print_mem_log(self, log, nodes, title=None):
|
||||
if title:
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
|
@ -523,12 +522,12 @@ def _print_mem_log(log, nodes, title=None):
|
|||
print("\n")
|
||||
|
||||
|
||||
def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
|
||||
def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
act_memory_after_node_log = []
|
||||
not_contiguous_list = []
|
||||
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
|
||||
user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
|
||||
_delete_free_var_from_last_use(user_to_last_uses)
|
||||
within_chunk = False
|
||||
region_idx = 0
|
||||
|
@ -539,12 +538,12 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
|
|||
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
||||
if idx in start_nodes:
|
||||
within_chunk = True
|
||||
chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
|
||||
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
|
||||
chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
|
||||
act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
|
||||
|
||||
# if node is placeholder, just add the size of the node
|
||||
if node.op == 'placeholder':
|
||||
act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# skip output
|
||||
elif node.op == 'output':
|
||||
|
@ -553,21 +552,21 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
|
|||
else:
|
||||
# forward memory
|
||||
# TODO: permute will create a tmp copy if not contiguous
|
||||
act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
|
||||
act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
|
||||
act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
# record max act memory
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# delete useless memory
|
||||
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
|
||||
act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
|
||||
if within_chunk:
|
||||
act_memory -= _get_chunk_delete_node_size(
|
||||
act_memory -= self._get_chunk_delete_node_size(
|
||||
node, user_to_last_uses, chunk_ratio, node_list,
|
||||
start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
|
||||
else:
|
||||
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
|
||||
act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
|
||||
|
||||
if idx in end_nodes:
|
||||
act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
within_chunk = False
|
||||
chunk_ratio = 1
|
||||
region_idx += 1
|
||||
|
@ -575,8 +574,8 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
|
|||
act_memory_after_node_log.append(act_memory)
|
||||
|
||||
print("chunk")
|
||||
_print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
_print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
|
||||
param_memory = parameter_size(gm)
|
||||
return act_memory + param_memory, param_memory
|
||||
|
@ -695,8 +694,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
within_chunk_region = False
|
||||
|
||||
node_list = list(nodes)
|
||||
_estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
|
||||
_estimate_inference_mem(meta_graph)
|
||||
memory_estimator = MemoryEstimator()
|
||||
memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
|
||||
memory_estimator.estimate_inference_mem(meta_graph)
|
||||
node_index_tracer = NodeIndexTracer(meta_graph)
|
||||
node_index_tracer.trace_node_idx()
|
||||
|
||||
|
|
Loading…
Reference in New Issue