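"""Tests for the Elixir chunk schedulers.

``exam_fifo`` checks ``FIFOScheduler``: chunks come off ``top()`` in the order
they were added, re-added chunks go to the back of the queue, and duplicate
``add``/``remove`` calls are no-ops.

``exam_prefetch`` checks ``PrefetchScheduler`` against the access trace in
``chunk_called_per_step``: after each ``step()``, ``top()`` selects the tracked
chunk whose next use lies farthest in the future (chunks with no remaining use
come first).
"""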
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import Chunk, MemoryPool
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag


def exam_fifo(nproc, group):
    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)
    c0 = Chunk(mp, 1024, torch.float, group)
    c1 = Chunk(mp, 1024, torch.float, group)
    c2 = Chunk(mp, 1024, torch.float, group)

    sdl = FIFOScheduler()
    sdl.reset()

    sdl.add(c0)
    sdl.add(c1)
    sdl.add(c2)
    sdl.add(c0)  # nothing happens here: c0 is already tracked
    assert sdl.top() == c0

    sdl.remove(c0)
    assert sdl.top() == c1, f'{sdl.top()}'
    sdl.remove(c0)  # removing an already-removed chunk changes nothing
    assert sdl.top() == c1, f'{sdl.top()}'

    # a re-added chunk goes to the back of the FIFO queue
    sdl.add(c0)
    assert sdl.top() == c1
    sdl.remove(c1)
    assert sdl.top() == c2
    sdl.remove(c2)
    assert sdl.top() == c0


def exam_prefetch(nproc, group):
    mp = MemoryPool('cuda')
    mp.allocate()
    c0 = Chunk(mp, 1024, torch.float, group)
    c1 = Chunk(mp, 1024, torch.float, group)
    c2 = Chunk(mp, 1024, torch.float, group)

    # access trace: the chunk touched at each of the ten scheduler steps
    chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]]

    sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
    print(sdl.next_step_dict)
    sdl.reset()

    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c0

    sdl.step()
    sdl.add(c1)
    assert sdl.top() == c1

    sdl.step()
    sdl.add(c2)
    assert sdl.top() == c2

    # from here on: drop the chunk used in this step, advance one step,
    # re-add it, then check which tracked chunk top() selects
    sdl.remove(c0)
    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c2

    sdl.remove(c0)
    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c0
    sdl.remove(c0)  # notice here

    sdl.remove(c1)
    sdl.step()
    sdl.add(c1)
    assert sdl.top() == c1

    sdl.remove(c2)
    sdl.step()
    sdl.add(c2)
    assert sdl.top() == c1

    sdl.remove(c2)
    sdl.step()
    sdl.add(c2)
    assert sdl.top() == c2
    sdl.remove(c2)  # notice here
    sdl.add(c0)  # notice here

    sdl.remove(c1)
    sdl.step()
    sdl.add(c1)
    assert sdl.top() == c1
    sdl.remove(c1)  # notice here

    sdl.remove(c0)
    sdl.step()
    sdl.add(c0)
    assert sdl.top() == c0

    sdl.remove(c0)
    sdl.clear()


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
    exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@run_on_environment_flag('ELX')
def test_chunk_scheduler(world_size=1):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_scheduler()