diff --git a/chunk_codegen.py b/chunk_codegen.py
index 1267f64cb..4ca33a4d5 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -6,6 +6,7 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable
 try:
     from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
     from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
     CODEGEN_AVAILABLE = True
 except:
     from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
@@ -18,6 +19,82 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']
 
 
+def _get_meta_node_size(x):
+    x = x.meta['tensor_meta']
+    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
+    return x
+
+
+def _get_output_node_size(n):
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(fwd_out)
+
+
+def _get_delete_node_size(user, user_to_last_uses):
+    if user.op in ('placeholder', 'output'):
+        return 0
+    nodes_to_delete = user_to_last_uses.get(user, [])
+    if len(nodes_to_delete):
+        delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
+        return delete_size
+    return 0
+
+
+def _get_last_usr(nodes):
+    node_to_last_use: Dict[Node, Node] = {}
+    user_to_last_uses: Dict[Node, List[Node]] = {}
+
+    def register_last_uses(n: Node, user: Node):
+        if n not in node_to_last_use:
+            node_to_last_use[n] = user
+            user_to_last_uses.setdefault(user, []).append(n)
+
+    for node in reversed(nodes):
+        map_arg(node.args, lambda n: register_last_uses(n, node))
+        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+    return user_to_last_uses
+
+
+def _estimate_inference_mem(gm: torch.fx.GraphModule):
+    act_memory = 0
+    act_memory_peak_log = []
+    act_memory_after_node_log = []
+    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    for node in gm.graph.nodes:
+        # if node is placeholder, just add the size of the node
+        if node.op == 'placeholder':
+            act_memory += _get_meta_node_size(node)
+        # skip output
+        elif node.op == 'output':
+            continue
+        # node is an operation, calculate tmp, output node and delete node memory
+        else:
+            # forward memory
+            act_memory += calculate_fwd_tmp(node)
+            # act_memory += calculate_fwd_out(node)
+            act_memory += _get_output_node_size(node)
+            # record max act memory
+            act_memory_peak_log.append(act_memory)
+            # delete useless memory
+            act_memory -= calculate_fwd_tmp(node)
+            act_memory -= _get_delete_node_size(node, user_to_last_uses)
+            act_memory_after_node_log.append(act_memory)
+
+    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
+    param_memory = parameter_size(gm)
+    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
+
+
+def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
+    node_size = 0
+    param_size = 0
+    for node in gm.graph.nodes:
+        node_size += calculate_fwd_tmp(node)
+        node_size += calculate_fwd_out(node)
+    param_size = parameter_size(gm)
+    return (node_size + param_size) / 1024**2, param_size / 1024**2
+
+
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
     new_shape = "["
     for idx, i in enumerate(shape):
@@ -342,7 +419,7 @@ def emit_ckpt_func(body,
         body.append(usage)
 
 
-def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
+def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the
     activation checkpoint codes.
@@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False
 
     node_list = list(nodes)
+    _estimate_inference_mem(meta_graph)
 
     # find the input and output var names for each offload region
     for idx, (start, end) in enumerate(chunk_regions):
@@ -418,6 +496,7 @@ if CODEGEN_AVAILABLE:
     class ChunkCodeGen(CodeGen):
         def __init__(self, meta_graph):
             super().__init__()
+            self.meta_graph = meta_graph
             self.meta_node = list(meta_graph.graph.nodes)
 
         def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
@@ -612,7 +691,7 @@ if CODEGEN_AVAILABLE:
 
             # if any node has a list of labels for activation_checkpoint, we
             # will use nested type of activation checkpoint codegen
-            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
+            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)
 
             if len(body) == 0:
                 # If the Graph has no non-placeholder nodes, no lines for the body
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index 547b983a9..1ab7d958b 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -2,6 +2,7 @@ import copy
 import torch
 import torch.nn.functional as F
 import pytest
+import torch.fx
 import torch.multiprocessing as mp
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
     pair = torch.randn(1, 32, 32, 128).cuda()
 
     # trace the module and replace codegen
-    tracer = ColoTracer(trace_act_ckpt=True)
-    graph = tracer.trace(model)
-    gm_prop = torch.fx.GraphModule(model, graph)
-    interp = MetaInfoProp(gm_prop)
+    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
+    interp = MetaInfoProp(gm_prop)
+    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
+
+    # now run it twice to get meta info in graph module, not necessary
+    gm = torch.fx.GraphModule(model, graph)
+    interp = MetaInfoProp(gm)
     interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
-
-    # annotate the chunk part
-    # for node in graph.nodes:
-    #     if node.name == "linear0":
-    #         setattr(node, "activation_offload", [0, True, False])
-    #     if node.name == "linear1":
-    #         setattr(node, "activation_offload", [0, True, False])
 
     codegen = ChunkCodeGen(gm_prop)
     graph.set_codegen(codegen)
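
Note: the peak-memory bookkeeping that _estimate_inference_mem adds above reduces to a last-use liveness analysis: every tensor is freed at the node that uses it last, and the peak is the high-water mark of the running total. Below is a minimal, self-contained sketch of that technique; the graph and all sizes are made up for illustration, where a real pass would read them from MetaInfoProp metadata.

# A minimal sketch of the last-use liveness estimate behind _estimate_inference_mem.
# All node names and sizes here are hypothetical.
from typing import Dict, List, Tuple

# (name, inputs, output_size_in_bytes), in topological order
ops: List[Tuple[str, List[str], int]] = [
    ("x", [], 4096),              # placeholder
    ("linear0", ["x"], 8192),
    ("relu", ["linear0"], 8192),
    ("linear1", ["relu"], 4096),  # final output
]

# Walk in reverse so the first user we see is the last use,
# mirroring _get_last_usr above.
last_use: Dict[str, str] = {}
for name, inputs, _ in reversed(ops):
    for inp in inputs:
        last_use.setdefault(inp, name)

sizes = {name: size for name, _, size in ops}
live = peak = 0
for name, inputs, size in ops:
    live += size                        # allocate this node's output
    peak = max(peak, live)              # record the high-water mark
    for inp in inputs:                  # free every input that dies here
        if last_use.get(inp) == name:
            live -= sizes[inp]

print(f"peak activation memory: {peak / 1024:.1f} KiB")  # 16.0 KiB

The same high-water mark is what act_memory_peak_log records node by node, just with calculate_fwd_tmp/_get_output_node_size supplying the sizes.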
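
For reference, the retraced test flow in chunk_codegen_run.py condenses to the sketch below. The import paths are assumptions based on the names the test already uses (ColoTracer, MetaInfoProp, MetaTensor); model, node, and pair stand in for the test fixtures.

# Hypothetical condensation of _run_offload_codegen's new two-module tracing flow.
import torch
import torch.fx
from colossalai.fx import ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from chunk_codegen import ChunkCodeGen    # the codegen patched in this diff

def build_chunk_gm(model, node, pair):
    # trace once with ColoTracer to get the graph whose code will be regenerated
    graph = ColoTracer().trace(
        model,
        meta_args={'node': node.to(torch.device('meta')),
                   'pair': pair.to(torch.device('meta'))})
    # symbolic_trace a second module purely to carry the profiler metadata
    gm_prop = torch.fx.symbolic_trace(model)
    MetaInfoProp(gm_prop).propagate(
        MetaTensor(node, fake_device='cuda:0'),
        MetaTensor(pair, fake_device='cuda:0'))
    # hand the metadata-bearing module to the codegen, then rebuild the module
    graph.set_codegen(ChunkCodeGen(gm_prop))
    return torch.fx.GraphModule(model, graph)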