mirror of https://github.com/hpcaitech/ColossalAI
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize


def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
    """Sends a tensor to the next member and receives a tensor from the previous member.
    This function returns the received tensor from the previous member.

    :param tensor_send_next: Tensor sent to next member
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor_send_next: :class:`torch.Tensor`
    :type parallel_mode: :class:`colossalai.context.ParallelMode`
    :return: The tensor received from the previous member
    :rtype: :class:`torch.Tensor`
    """
    buffer_shape = tensor_send_next.size()

    ops = []
    current_rank = gpc.get_global_rank()

    # pre-allocate the buffer that will hold the tensor from the previous rank
    tensor_recv_prev = torch.empty(buffer_shape,
                                   requires_grad=True,
                                   device=get_current_device(),
                                   dtype=tensor_send_next.dtype)

    # send to next rank
    send_next_op = torch.distributed.P2POp(
        torch.distributed.isend, tensor_send_next,
        gpc.get_next_global_rank(parallel_mode))
    ops.append(send_next_op)

    # receive from prev rank
    recv_prev_op = torch.distributed.P2POp(
        torch.distributed.irecv, tensor_recv_prev,
        gpc.get_prev_global_rank(parallel_mode))
    ops.append(recv_prev_op)

    # even ranks post the receive first while odd ranks post the send first,
    # so that neighboring ranks issue complementary operations and the
    # batched point-to-point calls cannot deadlock
    if current_rank % 2 == 0:
        ops = ops[::-1]

    reqs = torch.distributed.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    # to protect against a race condition when using batch_isend_irecv()
    synchronize()

    return tensor_recv_prev
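
As a usage sketch (an illustrative addition, not part of the upstream file): assuming the distributed environment has already been initialized, e.g. via colossalai.launch, and taking ParallelMode.PIPELINE as the ring group (both are assumptions made for demonstration), each rank can circulate a tensor around the ring until it returns to its origin:

    # minimal sketch; the group choice and tensor shape are illustrative
    x = torch.randn(4, 8, device=get_current_device())
    world_size = gpc.get_world_size(ParallelMode.PIPELINE)
    for _ in range(world_size):
        # each call shifts the tensor one position around the ring; after
        # world_size calls every rank holds its original tensor again
        x = ring_forward(x, ParallelMode.PIPELINE)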