import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
from torch.fx import GraphModule

import colossalai
from colossalai.autochunk.chunk_codegen import ChunkCodeGen
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.utils import free_port
from tests.test_autochunk.evoformer.evoformer import evoformer_base

with_codegen = True
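
# Helper checks: compare gradients and parameters between the eager module and the
# traced GraphModule (used by the currently commented-out backward test).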
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
        if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
            return False
    return True

def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
        if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
            return False
    return True
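
# Measure the chunked GraphModule's memory footprint and compare its forward
# outputs against the eager model.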
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node0 = node.clone()
    #     pair0 = pair.clone()
    #     model.graph(node0, pair0, now_mem)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("\ncode now:%.2f max:%.2f" % (new_now_mem - now_mem, new_max_mem - now_mem))
    torch.cuda.reset_peak_memory_stats()
    now_mem = torch.cuda.memory_allocated() / 1024**2
    with torch.no_grad():
        node1 = node.clone()
        pair1 = pair.clone()
        gm(node1, pair1)
    new_now_mem = torch.cuda.memory_allocated() / 1024**2
    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    print(
        "autochunk now mem: %.2f max mem: %.2f"
        % (new_now_mem - now_mem, new_max_mem - now_mem)
    )

    # test forward
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)
    assert torch.allclose(
        non_fx_out[0], fx_out[0], atol=1e-4
    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
        torch.abs(non_fx_out[0] - fx_out[0])
    )
    assert torch.allclose(
        non_fx_out[1], fx_out[1], atol=1e-4
    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
        torch.abs(non_fx_out[1] - fx_out[1])
    )

    # test backward
    # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
    # loss0.backward()
    # loss1 = fx_out[0].sum() + fx_out[1].sum()
    # loss1.backward()
    # assert _is_all_param_close(model, gm)
    # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
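
# Worker launched by mp.spawn below: build the Evoformer model, trace it with
# ColoTracer, attach the autochunk code generator, and run the forward check.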
def _run_offload_codegen(rank):
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = evoformer_base().cuda()
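    # NOTE: the node input tensor is elided in this excerpt; the shape used here is an
    # assumed placeholder consistent with the 300-residue pair tensor below.
    node = torch.randn(1, 100, 300, 256).cuda()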
    pair = torch.randn(1, 300, 300, 128).cuda()

    # trace the module and replace codegen
    graph = ColoTracer().trace(
        model,
        meta_args={
            "node": node.to(torch.device("meta")),
            "pair": pair.to(torch.device("meta")),
        },
    )
    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
    )

    # run meta info propagation a second time to attach meta info to the graph module (not strictly necessary)
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
    )
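
    # swap in the autochunk code generator so that recompile() emits chunked forward code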
    codegen = ChunkCodeGen(gm_prop)
    graph.set_codegen(codegen)
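    # NOTE: rebuilding the graph module is assumed here (the line is elided in this
    # excerpt) so the new codegen takes effect when recompiling.
    gm = ColoGraphModule(model, graph)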
    gm.recompile()

    # assert we have all the components
    # code = graph.python_code("self").src
    # print(code)

    _test_fwd(model, gm, node, pair)
    gpc.destroy()

@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_autochunk():
    mp.spawn(_run_offload_codegen, nprocs=1)
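
if __name__ == "__main__":
    # convenience entry point for running this check outside of pytest
    # (assumed; not part of the original excerpt)
    test_autochunk()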