You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_fx/test_codegen/test_offload_codegen.py

186 lines
6.6 KiB

import copy
import pytest
import torch
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
# Probe which code-generation entry point this colossalai build exposes.
# ``with_codegen`` records the outcome so the tests below can be skipped when
# ``ActivationCheckpointCodeGen`` is unavailable.
try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen

    with_codegen = True
except ImportError:
    # fall back to older pytorch version
    # Only a failed import should trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    from colossalai.fx.codegen import python_code_with_activation_checkpoint

    with_codegen = False
class MyNet(torch.nn.Module):
    """Toy model: seven ``Linear(4, 4)`` layers applied one after another.

    The layers are registered as ``linear0`` ... ``linear6``. The tests in
    this file look traced graph nodes up by exactly these names, so the
    attribute names must not change.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register linear0 .. linear6 in ascending order, which keeps the
        # parameter order identical to writing the assignments out by hand.
        for idx in range(7):
            setattr(self, f"linear{idx}", torch.nn.Linear(4, 4))

    def forward(self, x):
        """Feed ``x`` through ``linear0`` ... ``linear6`` in order and return the result."""
        stages = (
            self.linear0,
            self.linear1,
            self.linear2,
            self.linear3,
            self.linear4,
            self.linear5,
            self.linear6,
        )
        for stage in stages:
            x = stage(x)
        return x
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
    """Tell whether the gradients of ``m`` and ``gm`` agree parameter by parameter.

    Parameters are paired positionally by zipping ``m.parameters()`` with
    ``gm.parameters()``, and each pair of ``.grad`` tensors is compared with
    ``torch.allclose`` using its default tolerances.

    Returns:
        ``True`` if every pair is close, ``False`` as soon as one pair differs.
    """
    paired_params = zip(m.parameters(), gm.parameters())
    # ``all`` stops at the first mismatch, like an early ``return False``.
    return all(torch.allclose(ref.grad, traced.grad) for ref, traced in paired_params)
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
    """Assert that ``gm`` matches ``model`` on ``data`` in forward output and gradients.

    Args:
        model: the original module, used as the reference.
        gm: the graph module under test.
        data: input tensor passed to both modules.

    Raises:
        AssertionError: if the outputs are not exactly equal, or if the
            parameter gradients after backward are not close.
    """
    # Forward pass: outputs must be bit-for-bit equal.
    reference_out = model(data)
    traced_out = gm(data)
    assert torch.equal(reference_out, traced_out), "fx_out doesn't comply with original output"

    # Backward pass: reduce each output to a scalar loss and backpropagate,
    # reference first, then compare the accumulated gradients.
    reference_out.sum().backward()
    traced_out.sum().backward()
    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
def _run_offload_codegen(rank, world_size, port):
    """Exercise activation-offload code generation via ``ActivationCheckpointCodeGen``.

    Handed to ``spawn`` by ``test_act_ckpt_codegen``. ``rank``, ``world_size``
    and ``port`` are forwarded unchanged to ``colossalai.launch``.

    The function traces ``MyNet``, tags several nodes with
    ``activation_offload`` / ``activation_checkpoint`` metadata, checks that
    the generated source contains the expected hooks and context managers,
    and finally compares forward/backward results against the original model.
    """
    # colossalai must be launched so that the generated code can execute
    # colossalai.utils.activation_checkpoint.checkpoint.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Model and input both live on the GPU.
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # Trace the module and install the checkpoint-aware code generator.
    graph = ColoTracer(trace_act_ckpt=True).trace(model)
    graph.set_codegen(ActivationCheckpointCodeGen())

    # Metadata to attach, keyed by node name. Offload regions and one
    # checkpoint region are both annotated so that both kinds of input
    # offload are covered.
    node_annotations = {
        "linear0": {"activation_offload": [0, True, False]},
        "linear1": {"activation_offload": [0, True, False]},
        "linear2": {"activation_offload": [1, True, True]},
        "linear4": {"activation_offload": [2, False, True]},
        "linear5": {"activation_checkpoint": [0], "activation_offload": True},
    }
    for node in graph.nodes:
        for meta_key, meta_value in node_annotations.get(node.name, {}).items():
            node.meta[meta_key] = meta_value

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

    # Every one of these fragments has to appear in the generated source.
    code = graph.python_code("self").src
    expected_snippets = (
        "def pack_hook_input(self, x):",
        "def unpack_hook(self, packed):",
        "def pack_hook_no_input(self, x):",
        "setattr(x, 'offload', True)",
        "setattr(linear3, 'offload', False)",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):",
        "with torch.autograd.graph.save_on_cpu(pin_memory=True):",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):",
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)",
    )
    assert all(snippet in code for snippet in expected_snippets)

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
@rerun_if_address_is_in_use()
def test_act_ckpt_codegen():
    """Run the codegen-based offload check through ``spawn`` with a count of 1."""
    num_procs = 1
    spawn(_run_offload_codegen, num_procs)
def _run_offload_codegen_torch11(rank, world_size, port):
    """Exercise activation-offload code generation via the patched ``_python_code`` path.

    Handed to ``spawn`` by ``test_act_ckpt_python_code_torch11``. ``rank``,
    ``world_size`` and ``port`` are forwarded unchanged to
    ``colossalai.launch``.

    Instead of installing a codegen object, this variant binds
    ``python_code_with_activation_checkpoint`` to the traced graph as its
    ``_python_code`` method. The annotations and checks are otherwise the
    same as in ``_run_offload_codegen``.
    """
    # colossalai must be launched so that the generated code can execute
    # colossalai.utils.activation_checkpoint.checkpoint.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Model and input both live on the GPU.
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # Trace the module, then swap in the checkpoint-aware code generator by
    # binding the free function to this graph instance.
    graph = ColoTracer(trace_act_ckpt=True).trace(model)
    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

    # Metadata to attach, keyed by node name. Offload regions and one
    # checkpoint region are both annotated so that both kinds of input
    # offload are covered.
    node_annotations = {
        "linear0": {"activation_offload": [0, True, False]},
        "linear1": {"activation_offload": [0, True, False]},
        "linear2": {"activation_offload": [1, True, True]},
        "linear4": {"activation_offload": [2, False, True]},
        "linear5": {"activation_checkpoint": [0], "activation_offload": True},
    }
    for node in graph.nodes:
        for meta_key, meta_value in node_annotations.get(node.name, {}).items():
            node.meta[meta_key] = meta_value

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

    # Every one of these fragments has to appear in the generated source.
    code = graph.python_code("self").src
    expected_snippets = (
        "def pack_hook_input(self, x):",
        "def unpack_hook(self, packed):",
        "def pack_hook_no_input(self, x):",
        "setattr(x, 'offload', True)",
        "setattr(linear3, 'offload', False)",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):",
        "with torch.autograd.graph.save_on_cpu(pin_memory=True):",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):",
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)",
    )
    assert all(snippet in code for snippet in expected_snippets)

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
@rerun_if_address_is_in_use()
def test_act_ckpt_python_code_torch11():
    """Run the patched-``_python_code`` offload check through ``spawn`` with a count of 1.

    Currently skipped unconditionally; see the skip reason on the decorator.
    """
    num_procs = 1
    spawn(_run_offload_codegen_torch11, num_procs)
if __name__ == "__main__":
    # ``_run_offload_codegen`` requires rank, world_size and port, so calling
    # it with only a rank raises TypeError. Go through the test entry point
    # instead, which hands the worker to ``spawn``.
    test_act_ckpt_codegen()