from functools import partial
from typing import List
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
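
# Descriptive note (added): this test exercises object-list point-to-point (p2p)
# communication between pipeline-parallel ranks using _send_object/_recv_object
# from colossalai.communication.p2p_v2 -- rank 0 sends a list containing a random
# tensor to rank 1, and rank 3 sends one to rank 2.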
disable_existing_loggers()
world_size = 4
CONFIG = dict(parallel=dict(pipeline=world_size))
torch.manual_seed(123)


def check_layer(rank, world_size, port):
    disable_existing_loggers()
    # Initialize the distributed environment for this process.
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False)
    rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Rank 0 sends an object list to rank 1; rank 3 sends one to rank 2.
    if rank == 0:
        obj = [torch.randn(3,)]
        _send_object(obj, 1)

    if rank == 1:
        _recv_object(0)

    if rank == 2:
        _recv_object(3)

    if rank == 3:
        obj = [torch.randn(3,)]
        _send_object(obj, 2)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
    disable_existing_loggers()
    run_func = partial(check_layer, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_object_list_p2p()