import colossalai
import torch
import copy
from typing import List, Callable, Any, Tuple, Dict, Iterable

from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import (_Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target,
                            magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin)
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size

CODEGEN_AVAILABLE = True
__all__ = ['ChunkCodeGen']


class NodeIndexTracer(object):
    def __init__(self, gm) -> None:
        self.gm = gm
        self.nodes_list = list(gm.graph.nodes)
        self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
        self.idx_trace_equal = []
        self.idx_view_list = []
        self.idx_count = -1

    def add_index(self):
        """
        Increase the index count by one and return it, so that every newly seen dim
        gets a unique index.

        Returns:
            idx_count (int)
        """
        self.idx_count += 1
        return self.idx_count

    def inherit_computation(self, node_from, node_to):
        """
        Inherit computed dims from node_from to node_to.
        If a dim in node_from is marked as computed and it also exists in node_to,
        mark it as computed in node_to as well.

        Args:
            node_from (node): node to inherit from
            node_to (node): node that inherits the computation
        """
        _, compute_from = self.find_trace_from_node(node_from)
        idx_to, compute_to = self.find_trace_from_node(node_to)
        for i in compute_from:
            if i in idx_to and i not in compute_to:
                compute_to.append(i)

    def mark_idx_equal(self, idx1, idx2):
        """
        Mark two indices as equal.

        Args:
            idx1 (int): index count.
            idx2 (int): index count.
        """
        self.idx_trace_equal.append((idx1, idx2))

    def mark_computation(self, node, idx, dim):
        """
        Mark some dims of a node as computed.

        Args:
            node (node)
            idx (int): node index
            dim (list or int): dims to be marked as computed
        """
        input_node_idx_trace = self.find_idx_trace_from_node(node)
        if isinstance(dim, int):
            dim = [dim]
        for d in dim:
            cur_idx = input_node_idx_trace[d]
            if cur_idx not in self.idx_trace_list[idx]['compute']:
                self.idx_trace_list[idx]['compute'].append(cur_idx)

    def find_trace_from_node(self, node):
        """
        Find a node's idx trace and compute trace.

        Args:
            node (node)
        Returns:
            idx (list): idx trace of the node
            compute (list): computed idx of the node.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        node_dict = self.idx_trace_list[node_idx]
        return node_dict['idx'], node_dict['compute']

    def find_idx_trace_from_node(self, node):
        """
        Find a node's idx trace.

        Args:
            node (node)
        Returns:
            idx (list): idx trace of the node
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        return self.idx_trace_list[node_idx]['idx']

    def find_compute_trace_from_node(self, node):
        """
        Find a node's compute trace.

        Args:
            node (node)
        Returns:
            compute (list): computed idx of the node.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        return self.idx_trace_list[node_idx]['compute']

    def assign_index_as_input(self, node, node_idx):
        """
        Assign the node's index trace by copying the trace of its input node.

        Args:
            node (node)
            node_idx (int)
        """
        input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
        input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

    def assign_all_index(self, node, node_idx):
        """
        Add a new index for every dim of the node.

        Args:
            node (node)
            node_idx (int)
        """
        shape = node.meta['tensor_meta'].shape
        new_trace = []
        for _ in shape:
            new_trace.append(self.add_index())
        self.idx_trace_list[node_idx]['idx'] = new_trace

    def assign_transpose_index(self, node, node_idx):
        """
        Assign index for a transpose op.
        1. swap the input's two dims according to the transpose args
        2. inherit the input's computation

        Args:
            node (node)
            node_idx (int)
        """
        transpose_dim = node.args[1:]
        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        new_idx_trace[transpose_dim[0]] = input_node_idx_trace[transpose_dim[1]]
        new_idx_trace[transpose_dim[1]] = input_node_idx_trace[transpose_dim[0]]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
        self.inherit_computation(node.args[0], node)

    def assign_permute_index(self, node, node_idx):
        """
        Assign index for a permute op.
        1. reorder the input's dims according to the permute args
        2. inherit the input's computation

        Args:
            node (node)
            node_idx (int)
        """
        permute_dim = node.args[1:]
        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        for idx, d in enumerate(permute_dim):
            new_idx_trace[idx] = input_node_idx_trace[d]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
        self.inherit_computation(node.args[0], node)
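
    # Illustrative sketch (not part of the original file): if the input is traced with
    # indices [0, 1, 2, 3] (e.g. a [batch, head, seq, dim] tensor), `permute(0, 2, 1, 3)`
    # yields the reordered trace [0, 2, 1, 3]; the computed dims are inherited unchanged
    # because permute only relabels dimensions.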

    def assign_linear_index(self, node, node_idx):
        """
        Assign index for a linear op.
        1. copy the trace from the input node and change the last index according to the weight
        2. mark the input node's last index, the weight's first dim and the bias dim as equal
        3. inherit the input's computation and mark the last dim as computed

        Args:
            node (node)
            node_idx (int)
        """
        input_node, weight, bias = node.args
        input_node_idx_trace = self.find_idx_trace_from_node(input_node)
        weight_idx_trace = self.find_idx_trace_from_node(weight)

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        new_idx_trace[-1] = weight_idx_trace[1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self.inherit_computation(input_node, node)
        self.mark_computation(node, node_idx, [-1])
        self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])

        if bias:
            bias_idx_trace = self.find_idx_trace_from_node(bias)
            self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
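
    # Illustrative sketch (assumed traces, not part of the original file): with the input
    # traced as [0, 1, 2] and the weight traced as [3, 4], the linear output gets the
    # trace [0, 1, 4]; indices 2 and 3 are marked equal (they are reduced together), and
    # the output's last dim is marked as computed.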

    def assign_matmul_index(self, node, node_idx):
        """
        Assign index for a matmul op.
        1. copy the trace from matmul_left and change the last index according to matmul_right
           (the two inputs are asserted to have the same length)
        2. mark matmul_left's -1 index and matmul_right's -2 index as equal
        3. inherit matmul_left's and matmul_right's computation, and mark the last dim as computed

        Args:
            node (node)
            node_idx (int)
        """
        matmul_left, matmul_right = node.args
        matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
        matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)

        assert len(matmul_left_idx_trace) == len(matmul_right_idx_trace)
        new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
        new_idx_trace[-1] = matmul_right_idx_trace[-1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self.inherit_computation(matmul_left, node)
        self.inherit_computation(matmul_right, node)
        self.mark_computation(node, node_idx, [-1])
        self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
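
    # Illustrative sketch (assumed traces, not part of the original file): for
    # `torch.matmul(a, b)` with `a` traced as [0, 1, 2] and `b` traced as [0, 4, 3],
    # the output trace becomes [0, 1, 3]; indices 2 and 4 are marked equal, and the
    # output's last dim is marked as computed.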

    def assign_layernorm_index(self, node, idx):
        """
        Assign index for a layernorm op.
        1. assign the index from the input node
        2. inherit computation and mark the last 2 dims as computed.

        Args:
            node (node)
            idx (int)
        """
        self.assign_index_as_input(node, idx)
        self.inherit_computation(node.args[0], node)
        self.mark_computation(node, idx, [-1, -2])

    def assign_elementwise_index(self, node, idx):
        """
        Assign index for an element-wise op (e.g. relu, sigmoid, add, mul).
        1. assign the index from the input node
        2. inherit computation from all input nodes.

        Args:
            node (node)
            idx (int)
        """
        self.assign_index_as_input(node, idx)
        for node_in in node.args:
            if type(node_in) not in (int, float):
                self.inherit_computation(node_in, node)

    def assign_softmax_index(self, node, idx):
        """
        Assign index for a softmax op.
        1. assign the index from the input node
        2. inherit computation and mark the softmax dim as computed.

        Args:
            node (node)
            idx (int)
        """
        self.assign_index_as_input(node, idx)
        self.inherit_computation(node.args[0], node)
        self.mark_computation(node, idx, [node.kwargs['dim']])

    def assign_view_reshape_index(self, node, node_idx):
        """
        Assign index for view and reshape ops.
        1. get the origin shape and target shape from meta info.
        2. compute the real value of -1 in the target shape.
        3. determine the changed dims and assign an index for each generated dim.
        4. log the changed and generated dims so they can be restored later.
        5. inherit computation.
        6. TODO: look into the view list to see whether this view is associated with others;
           if so, assign equal dims according to the previous view.

        Args:
            node (node)
            node_idx (int)
        """
        # get data, turn into numbers
        origin_node = node.args[0]
        origin_shape = origin_node.meta['tensor_meta'].shape
        target_shape = []
        for i in range(1, len(node.args)):
            if isinstance(node.args[i], int):
                target_shape.append(node.args[i])
            else:
                target_shape.append(node.args[i].meta['fwd_out'][0])

        # compute the value of -1
        if -1 in target_shape:
            origin_product = 1
            for i in origin_shape:
                origin_product *= i
            target_product = -1
            for i in target_shape:
                target_product *= i
            shape_idx = target_shape.index(-1)
            target_shape[shape_idx] = origin_product // target_product

        # determine changed dims
        len_diff = len(origin_shape) - len(target_shape)
        if len_diff == 1:
            # dim merge
            dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
            dim_to = [dim_equal.index(False)]
            dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
        elif len_diff == -1:
            # dim expand
            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
            dim_from = [dim_equal.index(False)]
            dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
        else:
            raise NotImplementedError("shape " + str(origin_shape) + " and " + str(target_shape) + " view not implemented")

        # get new index
        origin_trace = self.find_idx_trace_from_node(origin_node)
        new_trace = copy.deepcopy(origin_trace)
        dim_from.reverse()
        for i in dim_from:
            new_trace.pop(i)
        for i in dim_to:
            new_trace.insert(i, self.add_index())
        self.idx_trace_list[node_idx]['idx'] = new_trace

        # inherit computation
        self.inherit_computation(origin_node, node)
        compute_log = self.find_compute_trace_from_node(origin_node)
        for i in dim_from:
            if origin_trace[i] in compute_log:
                for j in dim_to:
                    self.mark_computation(node, node_idx, [j])
                break

        # log view, not used now
        view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
                     "dim_from": dim_from,
                     "idx_to": [new_trace[i] for i in dim_to],
                     "dim_to": dim_to}
        self.idx_view_list.append(view_dict)
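
    # Illustrative sketch (assumed shapes, not part of the original file): reshaping a
    # tensor of shape (2, 3, 4) traced as [0, 1, 2] into (2, 12) is a dim merge with
    # dim_from = [1, 2] and dim_to = [1]; indices 1 and 2 are dropped and a fresh index
    # is inserted at position 1. If either merged dim was computed, the new dim is marked
    # as computed as well.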

    def trace_node_idx(self):
        for idx, node in enumerate(self.nodes_list):
            if node.op == 'placeholder':
                self.assign_all_index(node, idx)
            elif node.op == 'call_method':
                if 'transpose' in node.name:
                    self.assign_transpose_index(node, idx)
                elif 'permute' in node.name:
                    self.assign_permute_index(node, idx)
                elif 'view' in node.name or 'reshape' in node.name:
                    self.assign_view_reshape_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "method not implemented yet!")
            elif node.op == 'call_function':
                if 'linear' in node.name:
                    self.assign_linear_index(node, idx)
                elif 'matmul' in node.name:
                    self.assign_matmul_index(node, idx)
                elif 'softmax' in node.name:
                    self.assign_softmax_index(node, idx)
                elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
                    self.assign_elementwise_index(node, idx)
                elif 'getattr' in node.name:
                    continue  # get attr like shape
                elif 'getitem' in node.name:
                    continue  # get item in list
                else:
                    raise NotImplementedError(node.name, "function not implemented yet!")
            elif node.op == 'call_module':
                if any(n in node.name for n in ['layernorm', 'norm']):
                    self.assign_layernorm_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "module not implemented yet!")
            elif node.op == 'get_attr':
                self.assign_all_index(node, idx)  # get param
            elif node.op == 'output':
                continue
            else:
                raise NotImplementedError(node.op, "op not implemented yet!")
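

# A minimal usage sketch (assumptions: `gm` is a torch.fx GraphModule whose nodes already
# carry meta['tensor_meta'] and meta['fwd_out'], e.g. filled by the colossalai.fx profiler;
# not part of the original file):
#
#   tracer = NodeIndexTracer(gm)
#   tracer.trace_node_idx()
#   idx_trace = tracer.idx_trace_list[5]['idx']        # one symbolic index per dim of node 5
#   computed = tracer.idx_trace_list[5]['compute']     # dims of node 5 produced by a reduction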


def _get_meta_node_size(x):
    x = x.meta['tensor_meta']
    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
    return x


def _get_output_node_size(n):
    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
    return activation_size(fwd_out)


def _get_delete_node_size(user, user_to_last_uses):
    if user.op in ('placeholder', 'output'):
        return 0
    nodes_to_delete = user_to_last_uses.get(user, [])
    if len(nodes_to_delete):
        delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
        return delete_size
    return 0


def _get_last_usr(nodes):
    node_to_last_use: Dict[Node, Node] = {}
    user_to_last_uses: Dict[Node, List[Node]] = {}

    def register_last_uses(n: Node, user: Node):
        if n not in node_to_last_use:
            node_to_last_use[n] = user
            user_to_last_uses.setdefault(user, []).append(n)

    for node in reversed(nodes):
        map_arg(node.args, lambda n: register_last_uses(n, node))
        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
    return user_to_last_uses


def _delete_free_var_from_last_use(user_to_last_uses):
    for key, value in user_to_last_uses.items():
        for n in value:
            if n.op == 'placeholder':
                user_to_last_uses[key].remove(n)


def _get_contiguous_memory(node, not_contiguous_list, delete=False):
    mem = 0
    not_contiguous_ops = ['transpose', 'permute']

    if node.op == 'call_function' and 'matmul' in node.name:
        for n in node.args:
            if n in not_contiguous_list:
                # matmul won't change the origin tensor, but creates a temporary copy
                mem += _get_output_node_size(n)
    elif node.op == 'call_module':
        for n in node.args:
            if n in not_contiguous_list:
                # a module just makes the origin tensor contiguous
                if delete:
                    not_contiguous_list.remove(n)
    elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
        if node not in not_contiguous_list:
            not_contiguous_list.append(node)
    elif any(i in node.args for i in not_contiguous_list):
        if node not in not_contiguous_list:
            not_contiguous_list.append(node)

    return mem


def _estimate_inference_mem(gm: torch.fx.GraphModule):
    act_memory = 0.0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    not_contiguous_list = []
    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
    _delete_free_var_from_last_use(user_to_last_uses)
    for node in gm.graph.nodes:
        # if node is a placeholder, just add the size of the node
        if node.op == 'placeholder':
            act_memory += _get_meta_node_size(node) / (1024 ** 2)
            act_memory_peak_log.append(act_memory)
            act_memory_after_node_log.append(act_memory)
        # skip output
        elif node.op == 'output':
            continue
        # node is an operation: add tmp and output memory, then subtract deleted memory
        else:
            # forward memory
            act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
            act_memory += _get_output_node_size(node) / (1024 ** 2)
            # record peak act memory
            act_memory_peak_log.append(act_memory)
            # free memory that is no longer used
            act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
            act_memory_after_node_log.append(act_memory)

    print("no chunk")
    _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
    _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")

    param_memory = parameter_size(gm)
    return act_memory + param_memory, param_memory
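

# Illustrative usage sketch (assumes `gm` carries profiler meta info; not part of the
# original file):
#
#   total_mem, param_mem = _estimate_inference_mem(gm)
#
# Both values are in MB: `param_mem` counts parameters only, while `total_mem` adds the
# activation memory that remains after simulating allocation and freeing node by node.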


def _get_chunk_ratio(node, chunk_dim, chunk_size):
    shape = node.meta['tensor_meta'].shape
    chunk_ratio = float(chunk_size) / shape[chunk_dim]
    return chunk_ratio
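

# Illustrative sketch (assumed shape, not part of the original file): for a node whose
# meta shape is (2, 512, 64), chunk_dim=1 and chunk_size=2 give
#   chunk_ratio = 2 / 512 = 0.00390625
# i.e. each loop iteration is assumed to materialise only that fraction of the activation.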


def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
    if user.op in ('placeholder', 'output'):
        return 0
    nodes_to_delete = user_to_last_uses.get(user, [])
    delete_size = 0
    for n in nodes_to_delete:
        node_idx = _find_idx_by_name(n.name, node_list)
        if start_node <= node_idx < end_node:
            delete_size += _get_output_node_size(n) * chunk_ratio
    return delete_size


def _print_mem_log(log, nodes, title=None):
    if title:
        print(title)
    for idx, (l, n) in enumerate(zip(log, nodes)):
        print("%s:%.2f \t" % (n.name, l), end='')
        if (idx + 1) % 3 == 0:
            print("")
    print("\n")


def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
    act_memory = 0.0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    not_contiguous_list = []
    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
    _delete_free_var_from_last_use(user_to_last_uses)
    within_chunk = False
    region_idx = 0
    chunk_ratio = 1  # use it to estimate chunk mem
    node_list = list(gm.graph.nodes)

    for idx, node in enumerate(node_list):
        # if node is a chunk start node, update the chunk ratio and account for the chunk output buffer
        if idx in start_nodes:
            within_chunk = True
            chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)

        # if node is a placeholder, just add the size of the node
        if node.op == 'placeholder':
            act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
            act_memory_peak_log.append(act_memory)
        # skip output
        elif node.op == 'output':
            continue
        # node is an operation: add tmp and output memory, then subtract deleted memory
        else:
            # forward memory
            # TODO: permute will create a tmp copy if not contiguous
            act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
            act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
            # record peak act memory
            act_memory_peak_log.append(act_memory)
            # free memory that is no longer used
            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
            if within_chunk:
                act_memory -= _get_chunk_delete_node_size(
                    node, user_to_last_uses, chunk_ratio, node_list,
                    start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
            else:
                act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)

        if idx in end_nodes:
            act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
            within_chunk = False
            chunk_ratio = 1
            region_idx += 1

        act_memory_after_node_log.append(act_memory)

    print("chunk")
    _print_mem_log(act_memory_peak_log, node_list, "peak")
    _print_mem_log(act_memory_after_node_log, node_list, "after")

    param_memory = parameter_size(gm)
    return act_memory + param_memory, param_memory


def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
    new_shape = "["
    for idx, i in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape
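

# Illustrative sketch (not part of the original file):
#
#   _gen_chunk_slice_dim(1, "gen_chunk_idx", (2, 512, 64))
#   # -> "[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]"
#
# The returned string is pasted after a tensor name by the loop codegen below to slice
# that tensor along the chunk dim.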


def _get_first_non_single_dim(shape):
    for idx, i in enumerate(shape):
        if i == 1:
            continue
        else:
            return idx
    raise RuntimeError("can not get first non single dim for shape", shape)


def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    if len(chunk_input_meta) == 1:
        node = chunk_input_meta[0]
        node_shape = node.meta['tensor_meta'].shape
        chunk_dim = _get_first_non_single_dim(node_shape)
        chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
        out_shape = str(list(chunk_output.meta['tensor_meta'].shape))

        context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
            out_shape, node.name, node.name, chunk_size)
        context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
        context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
    else:
        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))
    return context
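

# Illustrative sketch of the generated prologue (assumed node named `x` with meta shape
# (1, 512, 64) and chunk_size=2; not part of the original file). The returned string
# roughly reads:
#
#   chunk_result = torch.empty([1, 512, 64], dtype=x.dtype, device=x.device); chunk_size = 2
#   for gen_chunk_idx in range(0, x.shape[1], chunk_size):
#    chunk_tensor = x[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]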


def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    chunk_inputs_name = chunk_inputs[0].name
    chunk_outputs_name = chunk_outputs.name
    chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
    chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape
    chunk_dim = _get_first_non_single_dim(chunk_output_shape)
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
    context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)

    context += chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"

    # determine whether this is the last use of the chunk input
    users_name = list(chunk_inputs[0].users.keys())
    if all([_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in users_name]):
        context += "; %s = None" % chunk_inputs_name

    context += "\n"
    return context
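

# Illustrative sketch of the generated epilogue (assumed chunk output node `attn_out` with
# meta shape (1, 512, 64) and chunk input `x`; not part of the original file). The returned
# string roughly reads:
#
#    chunk_result[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :] = attn_out
#   attn_out = chunk_result; chunk_result = None; chunk_size = None; x = None
#
# The first line still runs inside the generated loop (hence the extra indent); `x = None`
# is only appended when the chunk input has no users after the chunk region.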


def _find_input_and_output_nodes(nodes: List[Node]):
    """
    Find the input and output nodes of the given node list, i.e. nodes outside the list
    that feed into it and nodes outside the list that consume its outputs.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list,
    # we treat that input node as an input of the chunk region
    for node in nodes:
        for input_node in node._input_nodes.keys():
            node_repr = repr(input_node)
            if input_node not in nodes and node_repr not in input_nodes:
                input_nodes.append(input_node)

    # if a node has a user node which is not in the node list,
    # we treat that user node as the node receiving the current node's output
    for node in nodes:
        for output_node in node.users.keys():
            node_repr = repr(node)
            if output_node not in nodes and node_repr not in output_nodes:
                output_nodes.append(output_node)

    return input_nodes, output_nodes


def _find_idx_by_name(name, nodes_list):
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)


def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
    """Emit forward code with chunk regions.

    Nodes inside a chunk region are wrapped in a generated for-loop that processes the
    region's input slice by slice along a chunk dim, so only a fraction of the region's
    activations is alive at any time.

    Args:
        body: forward code
        ckpt_func: checkpoint functions code
        nodes: graph.nodes
        emit_node_func: function to emit a node
        delete_unused_value_func: function to delete values after their last use
        meta_nodes: nodes of the meta graph
        meta_graph: meta graph module used for memory estimation and index tracing
    """

    # chunk regions, hard-coded for now: (start_node_idx, end_node_idx)
    chunk_regions = [(58, 62)]
    chunk_starts = [item[0] for item in chunk_regions]
    chunk_ends = [item[1] for item in chunk_regions]
    chunk_inputs = []
    chunk_outputs = []
    within_chunk_region = False

    node_list = list(nodes)
    _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
    _estimate_inference_mem(meta_graph)
    node_index_tracer = NodeIndexTracer(meta_graph)
    node_index_tracer.trace_node_idx()

    # find the input and output var names for each chunk region
    for idx, (start, end) in enumerate(chunk_regions):
        offload_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
        chunk_inputs.append(inputs)
        chunk_outputs.append(outputs)
    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
    chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
    chunk_inputs_names = []
    for i in chunk_inputs:
        for j in i:
            chunk_inputs_names.append(j.name)

    # walk through the nodes and emit code, entering and leaving chunk regions as needed
    node_idx = 0
    region_idx = 0
    while node_idx < len(node_list):
        node = node_list[node_idx]

        if node_idx in chunk_starts:
            within_chunk_region = True

            # add the for-loop prologue
            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
            body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))

        if within_chunk_region:
            emit_node_func(node, body)
            # replace the input var with the chunk var
            if node_idx in chunk_starts:
                body[-1] = body[-1].replace("(" + chunk_inputs[region_idx][0].name + ")", '(chunk_tensor)')
            body[-1] = ' ' + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)

        else:
            emit_node_func(node, body)
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        if node_idx in chunk_ends:
            body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
            within_chunk_region = False
            region_idx += 1

        node_idx += 1
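

# Putting it together, an emitted forward with one chunk region roughly looks like the
# sketch below (assumed node names, not part of the original file):
#
#   chunk_result = torch.empty([1, 512, 64], dtype=x.dtype, device=x.device); chunk_size = 2
#   for gen_chunk_idx in range(0, x.shape[1], chunk_size):
#    chunk_tensor = x[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :]
#    dense_1 = self.dense_1(chunk_tensor)
#    ...
#    chunk_result[:, gen_chunk_idx:gen_chunk_idx + chunk_size, :] = dense_4
#   dense_4 = chunk_result; chunk_result = None; chunk_size = None
#
# Only the chunk start node reads the sliced `chunk_tensor`; every node up to the chunk
# end is emitted inside the loop, so its activation scales with chunk_size.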


if CODEGEN_AVAILABLE:

    class ChunkCodeGen(CodeGen):
        def __init__(self, meta_graph):
            super().__init__()
            self.meta_graph = meta_graph
            self.meta_node = list(meta_graph.graph.nodes)

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = ['']

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # set _custom_builtins here so that we needn't import colossalai in forward
            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return '()'

                typename = _type_repr(o)

                if hasattr(o, '__origin__'):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, '__args__'):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:

                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ', '.join(_get_repr(a) for a in args)
                kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f'{args_s}, {kwargs_s}'
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

            _delete_free_var_from_last_use(user_to_last_uses)

            # NOTE: we add a variable to distinguish body and ckpt_func
            def delete_unused_values(user: Node, body, to_keep=[]):
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                if user.op == 'placeholder':
                    return
                if user.op == 'output':
                    body.append('\n')
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
                nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
                if len(nodes_to_delete):
                    to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                    body.append(f'; {to_delete_str}\n')
                else:
                    body.append('\n')

            # NOTE: we add a variable to distinguish body and ckpt_func
            def emit_node(node: Node, body):
                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
                if node.op == 'placeholder':
                    assert isinstance(node.target, str)
                    maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
                    free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                    raw_name = node.target.replace('*', '')
                    if raw_name != repr(node):
                        body.append(f'{repr(node)} = {raw_name}\n')
                    return
                elif node.op == 'call_method':
                    assert isinstance(node.target, str)
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                        f'({_format_args(node.args[1:], node.kwargs)})')
                    return
                elif node.op == 'call_function':
                    assert callable(node.target)
                    # pretty print operators
                    if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                        assert isinstance(node.args, tuple)
                        body.append(f'{repr(node)}{maybe_type_annotation} = '
                                    f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                        return

                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
                        body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
                                    f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
                        return

                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)
                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if global_name == 'getattr' and \
                            isinstance(node.args, tuple) and \
                            isinstance(node.args[1], str) and \
                            node.args[1].isidentifier() and \
                            len(node.args) == 2:
                        body.append(
                            f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                        return
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                    if node.meta.get('is_wrapped', False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == 'call_module':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                    return
                elif node.op == 'get_attr':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                    return
                elif node.op == 'output':
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f'node: {node.op} {node.target}')

            # Modified for chunk codegen
            ckpt_func = []

            # emit the forward body; nodes inside a chunk region are wrapped in a
            # generated for-loop by emit_code_with_chunk
            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append('pass\n')

            if len(wrapped_fns) > 0:
                wrap_name = add_global('wrap', torch.fx.wrap)
                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
            else:
                wrap_stmts = ''

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            # prepend the generated helper functions (if any) to the forward definition
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            prologue = ''.join(ckpt_func) + prologue

            code = ''.join(body)
            code = '\n'.join(' ' + line for line in code.split('\n'))
            fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
            print(fn_code)
            return PythonCode(fn_code, globals_)