# mirror of https://github.com/hpcaitech/ColossalAI
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag


def exam_chunk_group_functions(nproc, group):
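    """Exercise the ChunkGroup API end to end: chunk allocation from a
    MemoryPool, rcache access/release, tensor state transitions and chunk
    reduction, checking tensor contents against eager copies throughout.
    """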
    a = torch.randn(3, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 32, device='cuda')
    copy_b = b.clone()
    c = torch.randn(256, device='cuda')
    copy_c = c.clone()
    d = torch.randn(2, 2, 64, device='cuda')
    copy_d = d.clone()
    e = torch.randn(2, 33, device='cuda')
    copy_e = e.clone()
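
    # the pool provides two public 256-element fp32 blocks for the unfused
    # chunks and one private 68-element block reserved for the fused chunk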
    mp = MemoryPool('cuda')
    mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
    cg = ChunkGroup(rcache=mp)
    c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
    c1 = cg.allocate_chunk([c], 256, torch.float, group)
    c2 = cg.allocate_chunk([d], 256, torch.float, group)
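
    # a fused chunk is backed by its own private block rather than the public
    # rcache blocks (hence the dedicated BlockRequire above); `e` has
    # 2 x 33 = 66 elements and fits the 68-element chunk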
    fused_config = dict(rcache_fused=True)
    c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config)
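
    # helpers asserting that chunked storage still matches the original data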
    def check_chunk_0():
        assert torch.equal(a, copy_a)
        assert torch.equal(b, copy_b)

    def check_chunk_1():
        assert torch.equal(c, copy_c)

    def check_chunk_2():
        assert torch.equal(d, copy_d)

    def check_chunk_3():
        assert torch.equal(e, copy_e)

    # check tensors_to_chunks
    chunks = cg.tensors_to_chunks([e, a])
    assert chunks[0] == c0
    assert chunks[1] == c3
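    # note: the result is ordered by chunk, not by input tensor order:
    # `a` lives in c0 while `e` lives in the fused c3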
    # check access_chunk for unfused chunks
    cg.access_chunk(c0)
    cg.access_chunk(c1)
    check_chunk_0()
    check_chunk_1()
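    # c0 and c1 now occupy both public blocks, so the rcache has no room for
    # c2; the fused c3 owns its private block and always passes the check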
    assert not cg.rcache_enough_check(c2)
    assert cg.rcache_enough_check(c3)
    # check access_chunk for fused chunks
    cg.access_chunk(c3)
    check_chunk_3()
    # check release_chunk for unfused chunks
    cg.release_chunk(c1)
    assert cg.rcache_enough_check(c2)
    # check access_chunk once a public block has been freed
    cg.access_chunk(c2)
    check_chunk_2()
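
    # walk `e` through the states a backward pass would produce, then reduce
    # the fused chunk; afterwards c3 holds owned data rather than a replica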
    cg.tensor_trans_state(e, TensorState.COMPUTE)
    cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD)
    cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE)
    cg.reduce_chunk(c3)
    assert not c3.is_replica

    torch.cuda.synchronize()
    print('chunk group functions are ok')


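# per-process entry point: set the env vars torch.distributed expects,
# then run the checks on the global process group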
def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)


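# gated behind the ELX environment flag; spawns one process per rank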
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_group(world_size=2)