mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.8 KiB
60 lines
1.8 KiB
1 year ago
|
import pytest
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
import colossalai
|
||
|
from colossalai.cluster import ProcessGroupMesh
|
||
|
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||
|
from colossalai.utils import get_current_device
|
||
|
|
||
|
|
||
|
def check_p2p_communication():
    """Exercise forward and backward pipeline P2P transfers between ranks.

    Rank 0 sends three payloads with ``send_forward`` -- a bare tensor, a
    one-element list holding that tensor, and a dict storing it under the
    key ``'tensor'`` -- and every other rank receives them with
    ``recv_forward`` and validates each one. The same three payloads are then
    sent with ``send_backward`` from rank 1 and validated via
    ``recv_backward`` on every other rank.

    The function reads ``dist.get_rank()``, so it is expected to run after
    the process group has been set up (``run_dist`` launches first).

    Raises:
        AssertionError: If a received payload does not match what was sent.
    """
    # NOTE(review): presumably a 1-D mesh of two processes, with the second
    # PipelineStageManager argument selecting the mesh axis used for
    # pipeline stages -- confirm against the colossalai API.
    pg_mesh = ProcessGroupMesh(2)
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager)

    rank = dist.get_rank()

    # Reference payload; both sender and receiver build the same value
    # locally so the receiver can compare without extra communication.
    tensor = torch.ones(1, device=get_current_device())

    if rank == 0:
        p2p.send_forward(tensor)
        p2p.send_forward([tensor])
        p2p.send_forward({'tensor': tensor})
    else:
        _assert_received_payloads(p2p.recv_forward, tensor)

    if rank == 1:
        p2p.send_backward(tensor)
        p2p.send_backward([tensor])
        p2p.send_backward({'tensor': tensor})
    else:
        _assert_received_payloads(p2p.recv_backward, tensor)


def _assert_received_payloads(recv_fn, expected):
    """Receive the three test payloads in send order and validate each one.

    Args:
        recv_fn: Zero-argument callable returning the next received object
            (``p2p.recv_forward`` or ``p2p.recv_backward``).
        expected: The tensor every payload is expected to carry.

    Raises:
        AssertionError: If any payload has the wrong container type, size,
            key, or tensor contents.
    """
    # 1) Bare tensor.
    obj = recv_fn()
    assert torch.equal(obj, expected)

    # 2) One-element list wrapping the tensor.
    obj = recv_fn()
    assert isinstance(obj, list), f"expected a list, got {type(obj).__name__}"
    assert len(obj) == 1, f"expected exactly one element, got {len(obj)}"
    assert torch.equal(obj[0], expected)

    # 3) Dict carrying the tensor under the 'tensor' key.
    obj = recv_fn()
    assert isinstance(obj, dict), f"expected a dict, got {type(obj).__name__}"
    assert 'tensor' in obj, "received dict is missing the 'tensor' key"
    assert torch.equal(obj['tensor'], expected)
|
||
|
|
||
|
|
||
|
def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai, then run the P2P checks.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: World size forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch``; the host is fixed to
            ``'localhost'``.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host='localhost',
    )
    check_p2p_communication()
|
||
|
|
||
|
|
||
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
    """Run the pipeline P2P communication check via ``spawn`` with ``run_dist``."""
    # check_p2p_communication has explicit rank 0 and rank 1 send branches,
    # hence the value 2 handed to spawn.
    world_size = 2
    spawn(run_dist, world_size)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Allow running this test directly as a script, without invoking pytest.
    test_pipeline_p2p()
|