mirror of https://github.com/hpcaitech/ColossalAI
add meta
parent 820ea4d056
commit f8aeecef46
@@ -366,6 +366,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
 if CODEGEN_AVAILABLE:
 
     class ChunkCodeGen(CodeGen):
+        def __init__(self, meta_graph):
+            super().__init__()
+            self.meta_node = list(meta_graph.graph.nodes)
 
         def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
             free_vars: List[str] = []
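In other words, ChunkCodeGen is no longer constructed with no arguments: it now takes a meta graph and keeps a snapshot of its node list, which _gen_python_code can consult while emitting chunked code. A minimal construction sketch, assuming chunk_codegen is importable and using a toy module in place of the real Evoformer (both the Tiny module and symbolic_trace stand-in are illustrative, not from the diff):

# Minimal construction sketch -- not the test from this commit.
import torch
from torch.fx import symbolic_trace
from chunk_codegen import ChunkCodeGen

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

meta_gm = symbolic_trace(Tiny())    # stand-in for a MetaInfoProp-annotated GraphModule
codegen = ChunkCodeGen(meta_gm)     # __init__ just stores list(meta_gm.graph.nodes)
print(len(codegen.meta_node))       # placeholder, call_function nodes, output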
@@ -9,6 +9,8 @@ import colossalai
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
+from colossalai.fx.profiler import MetaTensor
 from evoformer.evoformer import evoformer_base
 from chunk_codegen import ChunkCodeGen
 with_codegen = True
@@ -56,9 +58,10 @@ def _run_offload_codegen(rank):
     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(model)
-    # codegen = ChunkCodeGen()
-    # graph.set_codegen(codegen)
+    gm_prop = torch.fx.GraphModule(model, graph)
+    interp = MetaInfoProp(gm_prop)
+    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
 
     # annotate the chunk part
     # for node in graph.nodes:
     #     if node.name == "linear0":
@ -66,7 +69,9 @@ def _run_offload_codegen(rank):
|
||||||
# if node.name == "linear1":
|
# if node.name == "linear1":
|
||||||
# setattr(node, "activation_offload", [0, True, False])
|
# setattr(node, "activation_offload", [0, True, False])
|
||||||
|
|
||||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
codegen = ChunkCodeGen(gm_prop)
|
||||||
|
# graph.set_codegen(codegen)
|
||||||
|
gm = ColoGraphModule(model, graph)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
# assert we have all the components
|
# assert we have all the components
|
||||||
|
|
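Read together, the test-side changes build a throwaway torch.fx.GraphModule purely for metadata propagation, run MetaInfoProp on MetaTensor inputs, and hand that annotated module to ChunkCodeGen before recompiling; note that graph.set_codegen(codegen) is still commented out at this point. A condensed sketch of the flow inside _run_offload_codegen after this commit (model, node and pair are the Evoformer module and its inputs created earlier in the test):

# Condensed sketch of the updated test flow (not the verbatim test body).
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)

gm_prop = torch.fx.GraphModule(model, graph)        # used only for meta propagation
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device='cuda:0'),
                 MetaTensor(pair, fake_device='cuda:0'))

codegen = ChunkCodeGen(gm_prop)                     # consumes the annotated meta graph
# graph.set_codegen(codegen)                        # still disabled in this commit
gm = ColoGraphModule(model, graph)
gm.recompile()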