update test

pull/2364/head
oahzxl 2023-01-06 11:44:01 +08:00
parent efb1c64c30
commit 06a5355d98
1 changed file with 50 additions and 57 deletions

@@ -1,46 +1,20 @@
-import copy
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.fx
 import torch.multiprocessing as mp
-from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
 import colossalai
-from colossalai.utils import free_port
+from colossalai.autochunk.chunk_codegen import ChunkCodeGen
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import MetaTensor
+from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
-from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen
-with_codegen = True
-def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
-    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
-            return False
-    return True
-def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
-    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
-            return False
-    return True
-def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # with torch.no_grad():
-    #     node0 = node.clone()
-    #     pair0 = pair.clone()
-    #     model.graph(node0, pair0, now_mem)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print("\ncode now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
+def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     torch.cuda.reset_peak_memory_stats()
     now_mem = torch.cuda.memory_allocated() / 1024**2
     with torch.no_grad():
@@ -49,28 +23,38 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
         gm(node1, pair1)
     new_now_mem = torch.cuda.memory_allocated() / 1024**2
     new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    print("gm now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
+    print(
+        "autochunk now mem:%.2f max mem:%.2f"
+        % (new_now_mem - now_mem, new_max_mem - now_mem)
+    )
     # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
         fx_out = gm(node, pair)
-    assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0]))
-    assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1]))
-    # test barckward
-    # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
-    # loss0.backward()
-    # loss1 = fx_out[0].sum() + fx_out[1].sum()
-    # loss1.backward()
-    # assert _is_all_param_close(model, gm)
-    # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
+    assert torch.allclose(
+        non_fx_out[0], fx_out[0], atol=1e-4
+    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+        torch.abs(non_fx_out[0] - fx_out[0])
+    )
+    assert torch.allclose(
+        non_fx_out[1], fx_out[1], atol=1e-4
+    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+        torch.abs(non_fx_out[1] - fx_out[1])
+    )
 def _run_offload_codegen(rank):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
     # build model and input
     model = evoformer_base().cuda()
@@ -78,15 +62,25 @@ def _run_offload_codegen(rank):
     pair = torch.randn(1, 300, 300, 128).cuda()
     # trace the module and replace codegen
-    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
-    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
+    graph = ColoTracer().trace(
+        model,
+        meta_args={
+            "node": node.to(torch.device("meta")),
+            "pair": pair.to(torch.device("meta")),
+        },
+    )
+    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
-    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
+    interp.propagate(
+        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
+    )
     # now run it twice to get meta info in graph module, not necessary
     gm = torch.fx.GraphModule(model, graph)
     interp = MetaInfoProp(gm)
-    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
+    interp.propagate(
+        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
+    )
     codegen = ChunkCodeGen(gm_prop)
     graph.set_codegen(codegen)
@@ -94,15 +88,14 @@ def _run_offload_codegen(rank):
    gm.recompile()
     # assert we have all the components
-    code = graph.python_code("self").src
-    print(code)
-    _test_fwd_and_bwd(model, gm, node, pair)
+    # code = graph.python_code("self").src
+    # print(code)
+    _test_fwd(model, gm, node, pair)
     gpc.destroy()
-@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
-def test_act_ckpt_codegen():
+def test_autochunk():
     mp.spawn(_run_offload_codegen, nprocs=1)
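
For reference, a minimal sketch of how the updated entry point could be run outside of pytest. The `__main__` guard below is an assumption and not part of this diff; it needs a CUDA-capable machine with ColossalAI installed, since `_run_offload_codegen` launches the `nccl` backend.

if __name__ == "__main__":
    # Hypothetical direct runner; mirrors what pytest does when it collects test_autochunk(),
    # which spawns a single worker process via mp.spawn(_run_offload_codegen, nprocs=1).
    test_autochunk()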