# ColossalAI/tests/test_elixir/test_chunk/test_chunk.py

import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
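

# Unit tests for the Elixir Chunk abstraction. As the calls below show, the
# intended lifecycle is:
#   append_tensor() x4 -> close_chunk() -> access_chunk() -> release_chunk() / reduce_chunk()
# with per-tensor states (HOLD, COMPUTE, HOLD_AFTER_BWD, READY_FOR_REDUCE)
# tracked in chunk.tensor_state_cnter.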
def exam_chunk_functions(nproc, group):
    a = torch.randn(2, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 2, 128, device='cuda')
    copy_b = b.clone()
    c = torch.randn(128, device='cuda')
    copy_c = c.clone()
    d = torch.randn(4, 32, device='cuda')
    copy_d = d.clone()
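
    # 128 + 512 + 128 + 128 = 896 elements in total, so all four tensors fit
    # into a single 1024-element chunk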
    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)

    chunk = Chunk(mp, 1024, torch.float, group)
    chunk.l2_norm_flag = True
    assert chunk.chunk_size == 1024
    assert chunk.chunk_dtype == torch.float
    assert chunk.shard_size == 1024 // nproc

    def check_tensors():
        assert torch.equal(a, copy_a)
        assert torch.equal(b, copy_b)
        assert torch.equal(c, copy_c)
        assert torch.equal(d, copy_d)

    chunk.append_tensor(a)
    chunk.append_tensor(b)
    chunk.append_tensor(c)
    chunk.append_tensor(d)
    check_tensors()

    chunk.close_chunk()
    assert chunk.is_replica is False

    # check function: get_cpu_copy
    cpu_copys = chunk.get_cpu_copy()
    for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
        assert t_cpu.device.type == 'cpu'
        assert torch.equal(t_gpu.cpu(), t_cpu)

    # check function: access_chunk
    block = mp.get_public_block()
    chunk.access_chunk(block)
    assert chunk.is_replica
    assert chunk.scatter_check
    check_tensors()

    # check function: release_chunk
    chunk.optim_sync_flag = False
    block = chunk.release_chunk()
    assert block in mp.public_used_blocks
    assert chunk.is_replica is False
    assert chunk.optim_sync_flag is True

    # check function: access_chunk after release_chunk
    chunk.access_chunk(block)
    check_tensors()

    # check function: reduce_chunk
    norm = block.payload.float().norm(2)**2
    chunk.reduce_chunk()
    assert chunk.is_replica is False
    assert chunk.tensor_state_cnter[TensorState.HOLD] == 4

    test_norm = torch.Tensor([chunk.l2_norm]).cuda()
    dist.all_reduce(test_norm)
    assert torch.allclose(norm, test_norm)

    torch.cuda.synchronize()
    print('chunk functions are ok')
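

# exam_chunk_states drives the same four tensors through the chunk's tensor
# state machine using a fused chunk (rcache_fused=True), which, as this test
# suggests, keeps its payload in a private block rather than borrowing a
# public one from the pool.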
def exam_chunk_states(nproc, group):
    a = torch.randn(2, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 2, 128, device='cuda')
    copy_b = b.clone()
    c = torch.randn(128, device='cuda')
    copy_c = c.clone()
    d = torch.randn(4, 32, device='cuda')
    copy_d = d.clone()

    private = [BlockRequire(1024, torch.float)]
    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private)

    chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
    assert chunk.chunk_size == 1024
    assert chunk.chunk_dtype == torch.float
    assert chunk.shard_size == 1024 // nproc

    def check_tensors():
        assert torch.equal(a, copy_a)
        assert torch.equal(b, copy_b)
        assert torch.equal(c, copy_c)
        assert torch.equal(d, copy_d)

    chunk.append_tensor(a)
    chunk.append_tensor(b)
    chunk.append_tensor(c)
    chunk.append_tensor(d)
    check_tensors()

    chunk.close_chunk()
    assert chunk.is_replica is False

    chunk.access_chunk()
    assert chunk.is_replica
    check_tensors()
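
    # every appended tensor sits in HOLD once the chunk is accessible;
    # tensor_trans_state moves one tensor at a time through the state machine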
    assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
    chunk.tensor_trans_state(a, TensorState.COMPUTE)
    assert chunk.tensor_state_cnter[TensorState.HOLD] == 3
    assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1

    # walk every tensor through COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE
    tensor_list = [a, b, c, d]
    for t in tensor_list:
        chunk.tensor_trans_state(t, TensorState.COMPUTE)
        chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD)
        chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE)
    assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
    assert chunk.reduce_check

    torch.cuda.synchronize()
    print('chunk states are ok')
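

# run_dist is the per-process entry point: it fills in the environment
# variables that init_distributed() reads, then runs both exams on the
# default WORLD group.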
def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
    exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)
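

# the test is gated behind the ELX environment flag and spawns one worker
# process per rank via torch.multiprocessing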
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_functions(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_functions(world_size=4)