# ColossalAI/tests/test_elixir/test_chunk/test_group.py
# (source-viewer metadata: 99 lines, 2.9 KiB, Python)

import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
def exam_chunk_group_functions(nproc, group):
    """Walk a ChunkGroup through its main API and verify tensor payloads survive.

    Five CUDA tensors are packed into four chunks backed by a MemoryPool with
    two public 256-element blocks and one private 68-element block. The test
    then exercises ``tensors_to_chunks``, ``rcache_enough_check``,
    ``access_chunk``, ``release_chunk``, ``tensor_trans_state`` and
    ``reduce_chunk``. After each access it asserts that the affected tensors
    still equal the clones taken at creation time.

    NOTE: the call order below matters. Whether a chunk fits in the rcache
    depends on which chunks were accessed/released before it.

    Args:
        nproc: number of participating processes (not used by these checks).
        group: process group the chunks are allocated on.
    """

    def _rand_with_copy(*shape):
        # Draw a random CUDA tensor and keep an independent clone as reference.
        fresh = torch.randn(*shape, device='cuda')
        return fresh, fresh.clone()

    def _assert_unchanged(*pairs):
        # Each (live, snapshot) pair must still be element-wise identical.
        for live, snapshot in pairs:
            assert torch.equal(live, snapshot)

    # Same shapes and draw order as before: a, b, c, d, e.
    a, copy_a = _rand_with_copy(3, 64)
    b, copy_b = _rand_with_copy(2, 32)
    c, copy_c = _rand_with_copy(256)
    d, copy_d = _rand_with_copy(2, 2, 64)
    e, copy_e = _rand_with_copy(2, 33)

    pool = MemoryPool('cuda')
    pool.allocate(public_block_size=256,
                  public_block_number=2,
                  private_block_list=[BlockRequire(68, torch.float)])
    chunk_group = ChunkGroup(rcache=pool)

    # Three unfused 256-element chunks, then one fused 68-element chunk.
    chunk0 = chunk_group.allocate_chunk([a, b], 256, torch.float, group)
    chunk1 = chunk_group.allocate_chunk([c], 256, torch.float, group)
    chunk2 = chunk_group.allocate_chunk([d], 256, torch.float, group)
    chunk3 = chunk_group.allocate_chunk([e], 68, torch.float, group, {'rcache_fused': True})

    # tensors_to_chunks: for input [e, a] the test expects [chunk0, chunk3].
    found = chunk_group.tensors_to_chunks([e, a])
    assert found[0] == chunk0
    assert found[1] == chunk3

    # Access two unfused chunks. The pool has only two public blocks, so the
    # test expects no room left for chunk2, while fused chunk3 still fits.
    chunk_group.access_chunk(chunk0)
    chunk_group.access_chunk(chunk1)
    _assert_unchanged((a, copy_a), (b, copy_b), (c, copy_c))
    assert not chunk_group.rcache_enough_check(chunk2)
    assert chunk_group.rcache_enough_check(chunk3)

    # Fused-chunk access path.
    chunk_group.access_chunk(chunk3)
    _assert_unchanged((e, copy_e))

    # Releasing chunk1 should make room for chunk2.
    chunk_group.release_chunk(chunk1)
    assert chunk_group.rcache_enough_check(chunk2)
    chunk_group.access_chunk(chunk2)
    _assert_unchanged((d, copy_d))

    # Drive `e` through the state sequence the test uses before reducing its chunk.
    for next_state in (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE):
        chunk_group.tensor_trans_state(e, next_state)
    chunk_group.reduce_chunk(chunk3)
    assert not chunk3.is_replica

    torch.cuda.synchronize()
    print('chunk group functions are ok')
def run_dist(rank, world_size, port=29512):
    """Entry point executed by each spawned worker process.

    Publishes rank/world-size/rendezvous settings as environment variables
    and calls ``init_distributed``. ``init_distributed`` presumably consumes
    RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT from the
    environment (TODO confirm). The function then runs the chunk-group
    checks on the default world group and finally tears the process group
    down.

    Args:
        rank: index of this process, as supplied by ``torch.multiprocessing.spawn``.
        world_size: total number of spawned processes.
        port: TCP port for the rendezvous master. Defaults to 29512, the value
            previously hard-coded, so existing callers behave the same.
    """
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(port)

    init_distributed()
    exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)

    # Explicitly release the process group instead of leaving it to
    # interpreter exit, so communicator resources are freed deterministically.
    if dist.is_initialized():
        dist.destroy_process_group()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
    """Launch ``world_size`` processes that each execute ``run_dist``.

    Gated by ``run_on_environment_flag('ELX')`` (presumably skipped unless
    that environment flag is set; confirm against colossalai.testing).
    """
    # spawn invokes run_dist(rank, *args) for rank in range(nprocs).
    torch.multiprocessing.spawn(run_dist, args=(world_size,), nprocs=world_size)
if __name__ == '__main__':
    # Direct-run entry point: exercise the two-process configuration without pytest.
    world_size = 2
    test_chunk_group(world_size=world_size)