mirror of https://github.com/hpcaitech/ColossalAI
Jiarui Fang
3 years ago
committed by
GitHub
17 changed files with 52 additions and 20 deletions
@ -0,0 +1,32 @@
|
||||
import pytest |
||||
|
||||
import colossalai |
||||
from colossalai.utils.cuda import get_current_device |
||||
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity |
||||
from colossalai.utils import free_port |
||||
|
||||
from functools import partial |
||||
import torch.multiprocessing as mp |
||||
|
||||
|
||||
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): |
||||
frac1 = colo_device_memory_capacity(get_current_device()) |
||||
colo_set_process_memory_fraction(0.5) |
||||
frac2 = colo_device_memory_capacity(get_current_device()) |
||||
assert frac2 * 2 == frac1 |
||||
|
||||
|
||||
def run_dist(rank, world_size, port): |
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
||||
_run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() |
||||
|
||||
|
||||
@pytest.mark.dist |
||||
@pytest.mark.parametrize("world_size", [4, 5]) |
||||
def test_memory_utils(world_size): |
||||
run_func = partial(run_dist, world_size=world_size, port=free_port()) |
||||
mp.spawn(run_func, nprocs=world_size) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
test_memory_utils(world_size=2) |
Loading…
Reference in new issue