From f8aeecef46461ff574f51982d03310fa8c57888e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 3 Nov 2022 14:33:35 +0800 Subject: [PATCH] add meta --- chunk_codegen.py | 3 +++ chunk_codegen_run.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index cb2a3a8a9..1f336eb2b 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -366,6 +366,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if CODEGEN_AVAILABLE: class ChunkCodeGen(CodeGen): + def __init__(self, meta_graph): + super().__init__() + self.meta_node = list(meta_graph.graph.nodes) def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 7667fa691..b875b6308 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -9,6 +9,8 @@ import colossalai from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.profiler import MetaTensor from evoformer.evoformer import evoformer_base from chunk_codegen import ChunkCodeGen with_codegen = True @@ -56,9 +58,10 @@ def _run_offload_codegen(rank): # trace the module and replace codegen tracer = ColoTracer(trace_act_ckpt=True) graph = tracer.trace(model) - # codegen = ChunkCodeGen() - # graph.set_codegen(codegen) - + gm_prop = torch.fx.GraphModule(model, graph) + interp = MetaInfoProp(gm_prop) + interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) + # annotate the chunk part # for node in graph.nodes: # if node.name == "linear0": @@ -66,7 +69,9 @@ def _run_offload_codegen(rank): # if node.name == "linear1": # setattr(node, "activation_offload", [0, True, False]) - gm = ColoGraphModule(copy.deepcopy(model), graph) + codegen = ChunkCodeGen(gm_prop) + # graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph) gm.recompile() # assert we have all the components