2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.legacy.core import global_context as gpc
|
2021-10-28 16:21:23 +00:00
|
|
|
from colossalai.utils import get_current_device, synchronize
|
|
|
|
|
|
|
|
|
2022-04-25 05:41:43 +00:00
|
|
|
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
    """Pass a tensor one step around the ring and return what arrives.

    The given tensor is sent to the next member of the group while a tensor of
    the same shape and dtype is received from the previous member.

    Args:
        tensor_send_next (:class:`torch.Tensor`): Tensor sent to the next member.
        parallel_mode (ParallelMode): Parallel group mode used in this communication.

    Returns:
        :class:`torch.Tensor`: The tensor received from the previous member. It is
        allocated on the current device with ``requires_grad=True``.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    dist = torch.distributed
    is_even_rank = gpc.get_global_rank() % 2 == 0

    # Receive buffer mirrors the outgoing tensor's shape and dtype.
    recv_buffer = torch.empty(
        tensor_send_next.size(),
        dtype=tensor_send_next.dtype,
        device=get_current_device(),
        requires_grad=True,
    )

    # Outgoing transfer to the next rank, incoming transfer from the previous rank.
    send_op = dist.P2POp(dist.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode))
    recv_op = dist.P2POp(dist.irecv, recv_buffer, gpc.get_prev_global_rank(parallel_mode))

    # Even global ranks queue the receive first, odd ranks queue the send first.
    # NOTE(review): presumably so that neighbouring ranks do not all issue the
    # same operation first -- confirm against the communication backend.
    p2p_ops = [recv_op, send_op] if is_even_rank else [send_op, recv_op]

    for work in dist.batch_isend_irecv(p2p_ops):
        work.wait()

    # To protect against race condition when using batch_isend_irecv().
    synchronize()

    return recv_buffer
|