# ColossalAI/tests/test_elixir/test_chunk/test_block.py
# Tests for elixir chunk blocks (PublicBlock / PrivateBlock) and MemoryPool.

import torch
from colossalai.elixir.chunk import BlockSpec, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag
def test_block():
    """Check that public and private blocks expose correctly-described payloads.

    For each case a block is constructed from ``(numel, dtype, device)`` and
    its ``payload`` tensor is checked for element count, dtype and device
    type. The payload's byte footprint (numel * element size) must also
    equal the block's reported ``size_in_bytes``.
    """
    cases = (
        # public block, half precision, on the GPU
        (PublicBlock, 123, torch.float16, 'cuda'),
        # private block, single precision, in host memory
        (PrivateBlock, 77, torch.float, 'cpu'),
    )
    for block_cls, count, expected_dtype, device_kind in cases:
        blk = block_cls(count, expected_dtype, device_kind)
        tensor = blk.payload
        assert tensor.numel() == count
        assert tensor.dtype == expected_dtype
        assert tensor.device.type == device_kind
        assert tensor.numel() * tensor.element_size() == blk.size_in_bytes
    print('test_block: ok')


def test_memory_pool():
    """Exercise public-block pop/free bookkeeping and private-block lookup.

    A CUDA memory pool is given four public blocks and two private blocks.
    Each pop of a public block must make it appear among the used blocks,
    increment the used counter and decrement the free counter. Freeing the
    popped blocks must return them to the free blocks and restore the
    counters. Finally, a private block is fetched with the same
    ``(5, torch.float)`` pair used in its ``BlockSpec``.
    """
    pool = MemoryPool(device_type='cuda')
    total_public = 4
    pool.allocate_public_blocks(block_num=total_public)
    pool.allocate_private_blocks([BlockSpec(5, torch.float), BlockSpec(81, torch.float16)])

    # Pop two public blocks one at a time, checking counters after each pop.
    popped = []
    for used_so_far in (1, 2):
        blk = pool.pop_public_block()
        assert blk in pool.public_used_blocks
        assert pool.public_used_count == used_so_far
        assert pool.public_free_count == total_public - used_so_far
        popped.append(blk)

    # Hand them back (in pop order); the pool should be entirely free again.
    for blk in popped:
        pool.free_public_block(blk)
    for blk in popped:
        assert blk in pool.public_free_blocks
    assert pool.public_used_count == 0
    assert pool.public_free_count == total_public

    # Fetch a private block using the same size/dtype pair as its spec.
    private_blk = pool.get_private_block(5, torch.float)
    assert private_blk.numel == 5
    assert private_blk.dtype == torch.float
    print('test_memory_pool: ok')


if __name__ == '__main__':
    # Allow running the checks directly, without a test runner.
    for _check in (test_block, test_memory_pool):
        _check()