finish memory estimation

pull/2364/head
oahzxl 2022-11-11 15:43:03 +08:00
parent 22f9c60b6b
commit d7634af5c0
2 changed files with 80 additions and 47 deletions

View File

@@ -55,15 +55,49 @@ def _get_last_usr(nodes):
     return user_to_last_uses


+def _delete_free_var_from_last_use(user_to_last_uses):
+    for key, value in user_to_last_uses.items():
+        for n in value:
+            if n.op == 'placeholder':
+                user_to_last_uses[key].remove(n)
+
+
+def _get_contiguous_memory(node, not_contiguous_list, delete=False):
+    mem = 0
+    not_contiguous_ops = ['transpose', 'permute']
+    if node.op == 'call_function' and 'matmul' in node.name:
+        for n in node.args:
+            if n in not_contiguous_list:
+                # matmul won't change origin tensor, but create a tmp copy
+                mem += _get_output_node_size(n)
+    elif node.op == 'call_module':
+        for n in node.args:
+            if n in not_contiguous_list:
+                # module will just make origin tensor to contiguous
+                if delete:
+                    not_contiguous_list.remove(n)
+    elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
+        if node not in not_contiguous_list:
+            not_contiguous_list.append(node)
+    elif any(i in node.args for i in not_contiguous_list):
+        if node not in not_contiguous_list:
+            not_contiguous_list.append(node)
+    return mem
+
+
 def _estimate_inference_mem(gm: torch.fx.GraphModule):
-    act_memory = 0
+    act_memory = 0.0
     act_memory_peak_log = []
     act_memory_after_node_log = []
+    not_contiguous_list = []
     user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    _delete_free_var_from_last_use(user_to_last_uses)

     for node in gm.graph.nodes:
         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node)
+            act_memory += _get_meta_node_size(node) / (1024 ** 2)
             act_memory_peak_log.append(act_memory)
             act_memory_after_node_log.append(act_memory)
         # skip output
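For orientation, a minimal plain-PyTorch sketch (separate from this patch) of the behaviour the new _get_contiguous_memory helper charges for: transpose/permute return strided views that allocate nothing themselves, the estimator then assumes a downstream matmul materializes a temporary contiguous copy, while a module call is assumed to leave a contiguous result behind.

    import torch

    x = torch.randn(64, 128)
    y = x.transpose(0, 1)                    # strided view: no new storage yet
    print(y.is_contiguous())                 # False
    print(y.data_ptr() == x.data_ptr())      # True, the buffer is shared

    # the estimator assumes this matmul needs a temporary contiguous copy of y,
    # so it adds y's output size while y stays in not_contiguous_list
    z = torch.matmul(y, torch.randn(64, 32))

    # a call_module, by contrast, is assumed to leave a contiguous tensor behind,
    # which is why delete=True removes the node from not_contiguous_list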
@@ -72,25 +106,21 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
-            act_memory += calculate_fwd_tmp(node)
-            # act_memory += calculate_fwd_out(node)
-            act_memory += _get_output_node_size(node)
+            act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
+            act_memory += _get_output_node_size(node) / (1024 ** 2)
             # record max act memory
             act_memory_peak_log.append(act_memory)
             # delete useless memory
-            act_memory -= calculate_fwd_tmp(node)
-            act_memory -= _get_delete_node_size(node, user_to_last_uses)
+            act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
+            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
             act_memory_after_node_log.append(act_memory)

-    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
-    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
     print("no chunk")
-    _print_mem_log(act_memory_peak_log, "peak")
-    _print_mem_log(act_memory_after_node_log, "after")
+    _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
+    _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")

     param_memory = parameter_size(gm)
-    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
+    return act_memory + param_memory, param_memory


 def _get_chunk_ratio(node, chunk_dim, chunk_size):
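As a quick reference for the node.op cases the estimator dispatches on, a self-contained torch.fx sketch (the module and shapes are illustrative only):

    import torch
    import torch.fx

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, x):
            y = x.transpose(1, 2)                  # traced as a call_method node
            return torch.matmul(self.proj(x), y)   # call_module, then call_function

    gm = torch.fx.symbolic_trace(Tiny())
    for node in gm.graph.nodes:
        print(node.op, node.name)
    # placeholder x, call_method transpose, call_module proj,
    # call_function matmul, output output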
@@ -111,19 +141,23 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
     return delete_size


-def _print_mem_log(log, title=None):
+def _print_mem_log(log, nodes, title=None):
     if title:
-        print("%-8s" % title, end=' ')
-    for i in log:
-        print("%.2f " % i, end='')
-    print("")
+        print(title)
+    for idx, (l, n) in enumerate(zip(log, nodes)):
+        print("%s:%.2f \t" % (n.name, l), end='')
+        if (idx + 1) % 3 == 0:
+            print("")
+    print("\n")


 def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
-    act_memory = 0
+    act_memory = 0.0
     act_memory_peak_log = []
     act_memory_after_node_log = []
+    not_contiguous_list = []
     user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    _delete_free_var_from_last_use(user_to_last_uses)
     within_chunk = False
     region_idx = 0
     chunk_ratio = 1    # use it to estimate chunk mem
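A back-of-the-envelope sketch of what chunk_ratio buys: only chunk_size / full_dim of a chunked activation is live at a time. The 300-long dimension matches the enlarged pair input in the test file below; the chunk size of 30 is just an example value.

    full_dim, chunk_size = 300, 30
    chunk_ratio = chunk_size / full_dim              # 0.1 of the activation is live at once
    pair_bytes = 1 * 300 * 300 * 128 * 4             # fp32 pair activation, 4 bytes/element
    print(pair_bytes / 1024**2)                      # ~43.95 MB fully materialized
    print(pair_bytes * chunk_ratio / 1024**2)        # ~4.39 MB per chunk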
@@ -134,11 +168,11 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         if idx in start_nodes:
             within_chunk = True
             chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
-            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
+            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)

         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node) * chunk_ratio
+            act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
             act_memory_peak_log.append(act_memory)
         # skip output
         elif node.op == 'output':
@@ -146,36 +180,33 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
-            act_memory += calculate_fwd_tmp(node) * chunk_ratio
-            # act_memory += calculate_fwd_out(node)
-            act_memory += _get_output_node_size(node) * chunk_ratio
+            act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
+            act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
             # record max act memory
             act_memory_peak_log.append(act_memory)
             # delete useless memory
-            act_memory -= calculate_fwd_tmp(node) * chunk_ratio
+            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
             if within_chunk:
                 act_memory -= _get_chunk_delete_node_size(
-                    node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx])
+                    node, user_to_last_uses, chunk_ratio, node_list,
+                    start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
             else:
-                act_memory -= _get_delete_node_size(node, user_to_last_uses)
+                act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)

         if idx in end_nodes:
-            act_memory -= _get_output_node_size(node) * chunk_ratio
+            act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
             within_chunk = False
             chunk_ratio = 1
             region_idx += 1

         act_memory_after_node_log.append(act_memory)

-    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
-    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
     print("chunk")
-    _print_mem_log(act_memory_peak_log, "peak")
-    _print_mem_log(act_memory_after_node_log, "after")
+    _print_mem_log(act_memory_peak_log, node_list, "peak")
+    _print_mem_log(act_memory_after_node_log, node_list, "after")

     param_memory = parameter_size(gm)
-    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
+    return act_memory + param_memory, param_memory


 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
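A hedged usage sketch tying the two estimators together. The tracing call mirrors the test file below; the chunk_dims/chunk_sizes values and the assumption that the traced graph already carries the tensor metadata the size helpers read are assumptions, not part of this patch.

    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')),
                                                 'pair': pair.to(torch.device('meta'))})
    gm = ColoGraphModule(model, graph)

    total_mb, param_mb = _estimate_inference_mem(gm)
    chunk_mb, _ = _estimate_chunk_inference_mem(
        gm, start_nodes=[58], end_nodes=[62], chunk_dims=[1], chunk_sizes=[2])
    print("no chunk: %.2f MB, chunk: %.2f MB, params: %.2f MB" % (total_mb, chunk_mb, param_mb))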
@@ -516,7 +547,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     """
     # find the offload regions
-    chunk_regions = [(2, 6)]
+    chunk_regions = [(58, 62)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
@@ -684,6 +715,8 @@ if CODEGEN_AVAILABLE:
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))

+        _delete_free_var_from_last_use(user_to_last_uses)
+
         # NOTE: we add a variable to distinguish body and ckpt_func
         def delete_unused_values(user: Node, body, to_keep=[]):
             """

View File

@@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:

 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
-    # with torch.no_grad():
-    #     fx_out = gm(node, pair)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
+    now_mem = torch.cuda.memory_allocated() / 1024**2
+    with torch.no_grad():
+        node0 = node.clone()
+        pair0 = pair.clone()
+        node1, pair1 = gm(node0, pair0)
+    new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))

     # test forward
     with torch.no_grad():
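The updated measurement reports both the resident and the peak delta against the pre-forward allocation. A hedged, standalone variant (measure_forward_mb is a hypothetical helper, not part of the test) that also resets the peak counter so repeated calls stay comparable:

    def measure_forward_mb(gm, node, pair):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated() / 1024**2
        with torch.no_grad():
            gm(node.clone(), pair.clone())
        now = torch.cuda.memory_allocated() / 1024**2
        peak = torch.cuda.max_memory_allocated() / 1024**2
        return now - base, peak - base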
@@ -63,8 +63,8 @@ def _run_offload_codegen(rank):

     # build model and input
     model = evoformer_base().cuda()
-    node = torch.randn(1, 16, 32, 256).cuda()
-    pair = torch.randn(1, 32, 32, 128).cuda()
+    node = torch.randn(1, 100, 300, 256).cuda()
+    pair = torch.randn(1, 300, 300, 128).cuda()

     # trace the module and replace codegen
     graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
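For scale, the raw fp32 footprint of the enlarged inputs (plain arithmetic, not output from the patch):

    node_mb = 1 * 100 * 300 * 256 * 4 / 1024**2   # ~29.30 MB
    pair_mb = 1 * 300 * 300 * 128 * 4 / 1024**2   # ~43.95 MB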