from typing import Any, Dict, List

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    meta_args: List,
    concrete_args: List = None,
    max_memory: int = None,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    """Trace ``model`` with AutoChunk codegen, recompile it, and check that the chunked
    graph produces the same outputs as the original model. Returns the chunk infos found."""
    if concrete_args is None:
        concrete_args = []

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    # assert result
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    model.cuda()
    with torch.no_grad():
        out_gm = gm(*inputs)
        out_model = model(*inputs)
    out_gm = flat_list(out_gm)
    out_model = flat_list(out_model)
    for out_gm_i, out_model_i in zip(out_gm, out_model):
        assert torch.allclose(
            out_gm_i, out_model_i, atol=1e-4
        ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm_i - out_model_i))
    return chunks


def run_test(
    rank: int,
    data_args: tuple,
    max_memory: int,
    get_model: Any,
    get_data: Any,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    get_chunk_target: Any = None,
) -> None:
    """Launch a single-process colossalai context, build the model and inputs, and run the
    codegen correctness check, optionally verifying the found chunk regions against a target."""
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = get_model()
    meta_args, concrete_args = get_data(*data_args)

    chunks = assert_codegen_run(
        model,
        meta_args=meta_args,
        concrete_args=concrete_args,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )
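

# Hypothetical usage sketch (not part of this module): a pytest case would typically
# bind its own model/data factories and dispatch `run_test` on a single GPU process.
# The names `get_my_model`, `get_my_data`, and the data/memory arguments below are
# placeholders for illustration only.
#
#     from functools import partial
#     import torch.multiprocessing as mp
#
#     def test_autochunk_example():
#         run_func = partial(
#             run_test,
#             data_args=(32, 64),
#             max_memory=None,
#             get_model=get_my_model,
#             get_data=get_my_data,
#             print_code=False,
#             print_mem=False,
#             print_progress=False,
#         )
#         mp.spawn(run_func, nprocs=1)  # mp.spawn supplies `rank` as the first argument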