ColossalAI/tests/test_elixir/test_chunk/test_fetcher.py
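"""Test for Elixir's chunk fetcher.

A ResNet transformed by `hook_transform` runs one forward/backward pass next to
a plain torch DDP copy of the same model; the losses and the resulting
gradients of the two runs are compared with `assert_close`.
"""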

import copy

import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.elixir.chunk import ChunkGroup
from colossalai.elixir.utils import seed_all
from colossalai.testing import run_on_environment_flag, spawn

from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
from tests.test_elixir.utils import TEST_MODELS, to_cuda


def check_gradient(ddp_model, my_model, cg: ChunkGroup):
    # make every fused chunk accessible before reading its payload
    for chunk in cg.fused_chunks:
        cg.access_chunk(chunk)

    for (name, p0), p1 in zip(ddp_model.named_parameters(), my_model.parameters()):
        torch.cuda.synchronize()
        print(f'checking parameter {name}')
        # the chunk-based backward leaves the gradient in the parameter payload,
        # so p1.data is compared against the DDP gradient
        assert_close(p0.grad.data, p1.data)


def exam_chunk_fetcher(group):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    torch_model = model_fn().cuda()
    test_model = copy.deepcopy(torch_model)

    rank = dist.get_rank(group)
    # feed each rank different data so gradient reduction is actually exercised
    seed_all(1001 + rank)
    data = to_cuda(data_fn())

    # switch to a common seed with deterministic CUDA kernels before running both models
    seed_all(1001, cuda_deterministic=True)
    ddp_model = DDP(torch_model)
    ddp_loss = ddp_model(**data)
    ddp_loss.backward()

    hook_model, cg = hook_transform(test_model, group)
    my_loss = hook_model(**data)
    my_loss.backward()

    assert_close(ddp_loss, my_loss)
    check_gradient(ddp_model, hook_model, cg)
    print('private chunk fetcher is ok')


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_chunk_fetcher(group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
def test_chunk_fetcher(world_size):
    spawn(run_dist, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_fetcher(world_size=2)
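
# Usage note (assumed invocation, not stated in the file itself): the test is
# normally collected by pytest, e.g.
#   pytest tests/test_elixir/test_chunk/test_fetcher.py
# or launched directly from the repository root (so the `tests` package and
# `colossalai` are importable):
#   python tests/test_elixir/test_chunk/test_fetcher.py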