2023-06-28 09:12:19 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
import colossalai
|
2024-01-09 02:20:05 +00:00
|
|
|
from colossalai.accelerator import get_accelerator
|
2023-06-28 09:12:19 +00:00
|
|
|
from colossalai.cluster import ProcessGroupMesh
|
2024-01-08 07:37:27 +00:00
|
|
|
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
2023-06-28 09:12:19 +00:00
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
|
|
|
2023-12-22 02:44:00 +00:00
|
|
|
# Number of processes spawned for the test. check_p2p_communication only has
# branches for ranks 0 and 1, so this must stay at exactly two.
WORLD_SIZE = 2
|
|
|
|
|
2023-06-28 09:12:19 +00:00
|
|
|
|
|
|
|
def _assert_same(received, expected):
    """Assert that an object received over p2p matches the object the peer sent.

    ``tensor == tensor`` yields an elementwise boolean tensor, so a bare
    ``assert received == expected`` only works while every tensor has exactly
    one element. This helper compares tensors with ``torch.equal`` instead and
    recurses into lists/tuples and dicts, so the check stays valid for tensors
    of any size.

    Args:
        received: object returned by a ``PipelineP2PCommunication`` receive call.
        expected: object the peer rank passed to the matching send call.

    Raises:
        AssertionError: if the container types, lengths, keys or values differ.
    """
    if isinstance(expected, torch.Tensor):
        assert isinstance(received, torch.Tensor), f"expected a tensor, got {type(received)}"
        assert torch.equal(received, expected), f"tensor mismatch: {received} != {expected}"
    elif isinstance(expected, (list, tuple)):
        assert isinstance(received, type(expected)), f"expected {type(expected)}, got {type(received)}"
        assert len(received) == len(expected), f"length mismatch: {len(received)} != {len(expected)}"
        for received_item, expected_item in zip(received, expected):
            _assert_same(received_item, expected_item)
    elif isinstance(expected, dict):
        assert isinstance(received, dict), f"expected a dict, got {type(received)}"
        assert received.keys() == expected.keys(), f"key mismatch: {received.keys()} != {expected.keys()}"
        for key, expected_value in expected.items():
            _assert_same(received[key], expected_value)
    else:
        assert received == expected, f"{received!r} != {expected!r}"


def _check_forward_then_backward(p2p, rank, data):
    """Phase 1: rank 0 streams ``data`` forward, then both ranks do paired exchanges.

    In the paired loop rank 0 sends ``data`` in order and receives the reversed
    sequence that rank 1 sends backward. Each call on one rank is matched by
    the peer's call in the same position, so the order must not be changed.
    """
    if rank == 0:
        for obj in data:
            p2p.send_forward(obj)
        for obj, expected in zip(data, reversed(data)):
            # send_prior_fallback presumably picks send-vs-recv order on the
            # object fallback path; rank 1 sends backward first below.
            recv_obj = p2p.send_forward_recv_backward(obj, send_prior_fallback=False)
            _assert_same(recv_obj, expected)
    elif rank == 1:
        for obj in data:
            _assert_same(p2p.recv_forward(), obj)
        for expected, reply in zip(data, reversed(data)):
            p2p.send_backward(reply)
            _assert_same(p2p.recv_forward(), expected)


def _check_backward_then_forward(p2p, rank, data):
    """Phase 2: mirror of phase 1, with rank 1 streaming ``data`` backward first.

    In the paired loop rank 1 sends ``data`` in order and receives the reversed
    sequence that rank 0 sends forward.
    """
    if rank == 1:
        for obj in data:
            p2p.send_backward(obj)
        for obj, expected in zip(data, reversed(data)):
            recv_obj = p2p.send_backward_recv_forward(obj, send_prior_fallback=True)
            _assert_same(recv_obj, expected)
    elif rank == 0:
        for obj in data:
            _assert_same(p2p.recv_backward(), obj)
        for expected, reply in zip(data, reversed(data)):
            recv_obj = p2p.recv_backward()
            # Send before asserting, matching the original ordering, so the
            # peer's combined send/recv is always answered.
            p2p.send_forward(reply)
            _assert_same(recv_obj, expected)


def _check_tensor_without_metadata(p2p, rank, tensor):
    """Phase 3: exchange ``tensor`` with ``send_metadata=False``.

    The sender skips sending metadata, so each receiver supplies it up front
    via ``create_send_metadata(tensor)``.
    """
    if rank == 0:
        recv_obj = p2p.send_forward_recv_backward(
            tensor,
            send_metadata=False,
            metadata_recv=create_send_metadata(tensor),
        )
        _assert_same(recv_obj, tensor)
    elif rank == 1:
        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
        _assert_same(recv_obj, tensor)
        p2p.send_backward(tensor, send_metadata=False)


def check_p2p_communication():
    """Exercise ``PipelineP2PCommunication`` between two pipeline ranks.

    Must run inside an initialised distributed context of ``WORLD_SIZE`` (two)
    processes. Rank 0 sends forward and receives backward; rank 1 does the
    opposite. Three phases run in order:

    1. objects flow forward, then paired forward/backward exchanges;
    2. objects flow backward, then paired backward/forward exchanges;
    3. a tensor is exchanged with ``send_metadata=False`` while the receiver
       supplies the metadata itself.

    Raises:
        AssertionError: if any received object differs from what was sent.
    """
    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
    # NOTE(review): 0 presumably selects the pipeline axis of the mesh — confirm
    # against PipelineStageManager's signature.
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager)

    rank = dist.get_rank()

    tensor = torch.ones(1, device=get_accelerator().get_current_device())
    # A plain object, a bare tensor, and tensors nested in a list and a dict.
    data = [
        "tensor",
        tensor,
        [tensor],
        {"tensor": tensor},
    ]

    _check_forward_then_backward(p2p, rank, data)
    _check_backward_then_forward(p2p, rank, data)
    _check_tensor_without_metadata(p2p, rank, tensor)
|
2023-06-28 09:12:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process worker started by ``spawn``.

    Initialises ColossalAI's distributed context on localhost with an empty
    config, then runs the p2p communication check in this process.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
        world_size: total number of processes, forwarded to ``colossalai.launch``.
        port: rendezvous port, forwarded to ``colossalai.launch``.
    """
    launch_kwargs = dict(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    colossalai.launch(**launch_kwargs)
    check_p2p_communication()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
    """Spawn ``WORLD_SIZE`` worker processes, each running :func:`run_dist`."""
    num_procs = WORLD_SIZE
    spawn(run_dist, num_procs)
|
2023-06-28 09:12:19 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
if __name__ == "__main__":
    # Allow running this test directly as a script, outside of pytest.
    test_pipeline_p2p()
|