#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

import operator
from functools import reduce
from typing import List, Tuple, Union

import torch
import torch.distributed as dist

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device

from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks

TensorShape = Union[torch.Size, List[int], Tuple[int]]


def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[Union[TensorShape, int], bool]:
    """Get the exact tensor shape used when communicating and whether the tensor is chunked.

    Args:
        tensor_shape (:class:`torch.Size`): shape of tensor
        chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

    Returns:
        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int], int], bool]: the shape to communicate and
            whether the tensor is chunked. When chunking is requested and the number of elements is divisible
            by the tensor-parallel world size, the "shape" is the per-rank element count (an ``int``);
            otherwise ``tensor_shape`` is returned unchanged and the flag is ``False``.
    """
    if not chunk_tensor:
        return tensor_shape, chunk_tensor

    num_elements = reduce(operator.mul, tensor_shape, 1)
    tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
    if num_elements % tensor_parallel_world_size == 0:
        return num_elements // tensor_parallel_world_size, chunk_tensor

    # The tensor cannot be split evenly, so fall back to sending it whole.
    return tensor_shape, False


def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
    """Allocate receive buffer(s) on the current device for the given shape(s).

    Args:
        recv_shapes (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): a single shape, or an iterable of
            shapes. Only a :class:`torch.Size` instance is treated as a single shape.
        dtype (torch.dtype): data type of the buffers.
        scatter_gather_tensors (bool): whether tensors are chunked for communication.

    Returns:
        Tuple: ``(buffer, split)``. For a single shape, ``buffer`` is one tensor and ``split`` tells whether it
            is a chunk. For several shapes, ``buffer`` is a list of tensors and ``split`` is the flag computed
            for the last shape (``False`` when no shapes are given).
    """
    if isinstance(recv_shapes, torch.Size):
        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
        return buffer_recv, recv_split

    buffer_recv = []
    # Defined up front so that an empty ``recv_shapes`` does not raise UnboundLocalError.
    recv_split = False
    for recv_shape in recv_shapes:
        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
        buffer_recv.append(tensor_recv)
    return buffer_recv, recv_split


def process_object_to_send(object_send, scatter_gather_tensors):
    """Prepare the tensor(s) to send, splitting each into a 1D chunk when chunking applies to it.

    Args:
        object_send (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor or tensors to send.
        scatter_gather_tensors (bool): whether tensors are chunked for communication.

    Returns:
        Union[:class:`torch.Tensor`, Tuple[:class:`torch.Tensor`]]: the processed tensor, or a tuple of
            processed tensors when an iterable of tensors is given.
    """
    if isinstance(object_send, torch.Tensor):
        if _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]:
            object_send = split_tensor_into_1d_equal_chunks(object_send)
        return object_send

    # The split decision is made per tensor, since shapes may differ.
    return tuple(
        split_tensor_into_1d_equal_chunks(tensor_send)
        if _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
        else tensor_send
        for tensor_send in object_send
    )


def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
    """Append one ``dist.P2POp`` per tensor in ``obj`` to ``ops_queue``.

    Args:
        obj (Union[:class:`torch.Tensor`, Iterable[:class:`torch.Tensor`]]): tensor or tensors to communicate.
        comm_op: the point-to-point operation (``dist.isend`` or ``dist.irecv`` in this module).
        comm_rank (int): the peer rank.
        ops_queue (list): list the operations are appended to (modified in place).
    """
    if isinstance(obj, torch.Tensor):
        ops_queue.append(dist.P2POp(comm_op, obj, comm_rank))
    else:
        for tensor_to_comm in obj:
            ops_queue.append(dist.P2POp(comm_op, tensor_to_comm, comm_rank))


def _gather_received_tensors(tensor_recv, recv_shape, scatter_gather_tensors):
    """Gather received chunk(s) back into full tensor(s) of the requested shape(s).

    The decision is made per tensor with ``_get_tensor_shape``, mirroring how the receive buffers are created
    and how the send side decides whether to split, so a list mixing splittable and non-splittable shapes is
    handled correctly. Buffers that were not chunked are returned untouched.

    Args:
        tensor_recv (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): received buffer(s); a list is
            updated in place.
        recv_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): the shape(s) used to create
            ``tensor_recv``.
        scatter_gather_tensors (bool): whether tensors are chunked for communication.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: the tensor(s) in their full shape(s).
    """
    if isinstance(tensor_recv, torch.Tensor):
        if _get_tensor_shape(recv_shape, scatter_gather_tensors)[1]:
            tensor_recv = gather_split_1d_tensor(tensor_recv).view(recv_shape).requires_grad_()
        return tensor_recv

    for index, shape in enumerate(recv_shape):
        if _get_tensor_shape(shape, scatter_gather_tensors)[1]:
            tensor_recv[index] = gather_split_1d_tensor(tensor_recv[index]).view(shape).requires_grad_()
    return tensor_recv


def _communicate(
    object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
    object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
    recv_prev: bool = False,
    recv_next: bool = False,
    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
    prev_rank: int = None,
    next_rank: int = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Adapted from megatron.p2p_communication.
    Communicate tensors between stages. Used as helper method in other
    communication methods that are used in pipeline schedule.
    Takes the following arguments:
        object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
            (no tensor sent if set to None).
        object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
            (no tensor sent if set to None).
        recv_prev (bool): boolean for whether tensor should be received from
                   previous rank.
        recv_next (bool): boolean for whether tensor should be received from
                   next rank.
        recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
            from the previous stage, defaults to None.
        recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
            from the next stage, defaults to None.
        prev_rank (int): the rank of the previous pipeline stage, defaults to None,
        next_rank (int): the rank of the next pipeline stage, defaults to None,
        dtype (torch.dtype): data type of intermediate buffers, defaults to None
        scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False

    Returns:
        Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
    """

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    if recv_prev:
        assert recv_prev_shape is not None
        tensor_recv_prev, _ = create_recv_buffer_with_shapes(recv_prev_shape, dtype, scatter_gather_tensors)

    if recv_next:
        assert recv_next_shape is not None
        tensor_recv_next, _ = create_recv_buffer_with_shapes(recv_next_shape, dtype, scatter_gather_tensors)

    # Resolve peer ranks lazily: they are only looked up for directions that are actually used.
    if (object_send_prev is not None or recv_prev) and prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    if (object_send_next is not None or recv_next) and next_rank is None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    if object_send_prev is not None:
        object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)

    if object_send_next is not None:
        object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)

    # NOTE: the order in which operations are queued is preserved from the original implementation.
    ops = []
    if object_send_prev is not None:
        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

    if tensor_recv_prev is not None:
        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

    if tensor_recv_next is not None:
        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

    if object_send_next is not None:
        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    if recv_prev:
        tensor_recv_prev = _gather_received_tensors(tensor_recv_prev, recv_prev_shape, scatter_gather_tensors)

    if recv_next:
        tensor_recv_next = _gather_received_tensors(tensor_recv_next, recv_next_shape, scatter_gather_tensors)

    return tensor_recv_prev, tensor_recv_next


def recv_forward(
    input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

    Args:
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.
        prev_rank (int, optional): The rank of the source of the tensor.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
    """
    input_tensor, _ = _communicate(
        recv_prev=True,
        recv_prev_shape=input_tensor_shape,
        prev_rank=prev_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return input_tensor


def recv_backward(
    output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

    Args:
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.
        next_rank (int, optional): The rank of the source of the tensor.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient
            tensor list.
    """
    _, output_tensor_grad = _communicate(
        recv_next=True,
        recv_next_shape=output_grad_shape,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return output_tensor_grad


def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
    """Sends the output tensor to the next stage in pipeline.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        next_rank (int, optional): The rank of the recipient of the tensor.
    """
    _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
    """Sends the gradient tensor to the previous stage in pipeline.

    Args:
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
        prev_rank (int, optional): The rank of the recipient of the tensor
    """
    _communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)


def send_forward_recv_backward(
    output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the output tensor to the
    next stage in pipeline, while receives the gradient tensor from the
    next stage in pipeline as the input gradient tensor of this stage.

    Nothing is received (and ``None`` is returned) when ``output_grad_shape`` is None.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
    """
    _, output_tensor_grad = _communicate(
        object_send_next=output_tensor,
        recv_next=output_grad_shape is not None,
        recv_next_shape=output_grad_shape,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )

    return output_tensor_grad


def send_backward_recv_forward(
    input_tensor_grad,
    input_tensor_shape,
    prev_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the gradient tensor to the
    previous stage in pipeline, while receives the output tensor from the
    previous stage in pipeline as the input of this stage.

    Nothing is received (and ``None`` is returned) when ``input_tensor_shape`` is None.

    Args:
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
    """
    input_tensor, _ = _communicate(
        object_send_prev=input_tensor_grad,
        recv_prev=input_tensor_shape is not None,
        recv_prev_shape=input_tensor_shape,
        prev_rank=prev_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )

    return input_tensor


def send_forward_recv_forward(
    output_tensor,
    input_tensor_shape,
    prev_rank=None,
    next_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the output tensor to the
    next stage in pipeline, while receives the output tensor from the
    previous stage in pipeline as the input of this stage.

    Nothing is received (and ``None`` is returned) when ``input_tensor_shape`` is None.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
    """
    input_tensor, _ = _communicate(
        object_send_next=output_tensor,
        recv_prev=input_tensor_shape is not None,
        recv_prev_shape=input_tensor_shape,
        prev_rank=prev_rank,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return input_tensor


def send_backward_recv_backward(
    input_tensor_grad,
    output_grad_shape,
    prev_rank=None,
    next_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the gradient tensor to the
    previous stage in pipeline, while receives the gradient tensor from the
    next member in pipeline as the input of this stage.

    Nothing is received (and ``None`` is returned) when ``output_grad_shape`` is None.

    Args:
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
    """
    _, output_tensor_grad = _communicate(
        object_send_prev=input_tensor_grad,
        recv_next=output_grad_shape is not None,
        recv_next_shape=output_grad_shape,
        prev_rank=prev_rank,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor,
    input_tensor_grad,
    input_tensor_shape,
    output_grad_shape,
    prev_rank=None,
    next_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
    """Batched communication operation. Sends the output tensor to the next stage in pipeline and
    the gradient tensor to the previous stage, while receives the input gradient tensor from the
    next stage and the input tensor from the previous stage.

    A direction is only received when its shape argument is not None.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
            from the previous.
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
            from the next.

    Returns:
        Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`,
            List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
    """
    input_tensor, output_tensor_grad = _communicate(
        object_send_next=output_tensor,
        object_send_prev=input_tensor_grad,
        recv_prev=input_tensor_shape is not None,
        recv_next=output_grad_shape is not None,
        recv_prev_shape=input_tensor_shape,
        recv_next_shape=output_grad_shape,
        prev_rank=prev_rank,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return input_tensor, output_tensor_grad


def send_forward_and_recv_next_forward_async(
    output_tensor,
    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors=False,
):
    """Send forward output to next rank and recv forward input from prev rank.

    This is a generator driven in two steps: the first ``next()`` launches the batched communication and
    yields ``None``; the second ``next()`` waits for completion and yields the received tensor(s), or
    ``None`` when ``recv_prev_shape`` is None.
    """
    reqs = []
    tensor_recv_prev = None

    # prepare send operations
    if output_tensor is not None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

        output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
        filling_ops_queue(output_tensor, dist.isend, next_rank, reqs)

    # prepare receive operations
    if recv_prev_shape is not None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
        # create receive buffer
        tensor_recv_prev, _ = create_recv_buffer_with_shapes(recv_prev_shape, dtype, scatter_gather_tensors)
        # generate async receive operations
        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, reqs)

    if len(reqs) > 0:
        reqs = dist.batch_isend_irecv(reqs)

    # return and do other things
    yield

    # check communication completed
    for req in reqs:
        req.wait()
    # To protect against race condition when using batch_isend_irecv()
    torch.cuda.synchronize()

    # Process received data
    if recv_prev_shape is not None:
        tensor_recv_prev = _gather_received_tensors(tensor_recv_prev, recv_prev_shape, scatter_gather_tensors)

    yield tensor_recv_prev


def send_backward_and_recv_next_backward_async(
    input_tensor,
    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors=False,
):
    """Send ``input_tensor`` to prev rank and recv a tensor from next rank.

    This is a generator driven in two steps: the first ``next()`` launches the batched communication and
    yields ``None``; the second ``next()`` waits for completion and yields the received tensor(s), or
    ``None`` when ``recv_next_shape`` is None.
    """
    reqs = []
    tensor_recv_next = None

    # prepare send operations
    if input_tensor is not None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

        input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
        filling_ops_queue(input_tensor, dist.isend, prev_rank, reqs)

    # prepare receive operations
    if recv_next_shape is not None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
        # create receive buffer
        tensor_recv_next, _ = create_recv_buffer_with_shapes(recv_next_shape, dtype, scatter_gather_tensors)
        # generate async receive operations
        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, reqs)

    if len(reqs) > 0:
        reqs = dist.batch_isend_irecv(reqs)

    # return and do other things
    yield

    # check communication completed
    for req in reqs:
        req.wait()
    # To protect against race condition when using batch_isend_irecv()
    torch.cuda.synchronize()

    # Process received data
    if recv_next_shape is not None:
        tensor_recv_next = _gather_received_tensors(tensor_recv_next, recv_next_shape, scatter_gather_tensors)

    yield tensor_recv_next


class AsynCommunicator:
    """AsynCommunicator for managing async communication.

    Wraps one of the two-step async generators of this module: ``start`` launches the communication and
    ``wait_and_receive`` waits for it and returns what was received.
    """

    def __init__(
        self,
        tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
        recv_shape: Union[torch.Size, List[torch.Size]],
        dtype: torch.dtype = None,
        scatter_gather_tensors=False,
        forward: bool = True,
    ) -> None:
        """Create the underlying generator without starting any communication.

        Args:
            tensor_to_send (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor(s) to send
                (nothing is sent if None).
            recv_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape(s) of the tensor(s) to
                receive (nothing is received if None).
            dtype (torch.dtype): data type of the receive buffers, defaults to None.
            scatter_gather_tensors (bool): whether tensors are chunked for communication, defaults to False.
            forward (bool): use the forward generator (send to next, receive from prev) when True, otherwise
                the backward one (send to prev, receive from next). Defaults to True.
        """
        self._need_receive = recv_shape is not None

        if forward:
            self._coroutine = send_forward_and_recv_next_forward_async(
                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
            )
        else:
            self._coroutine = send_backward_and_recv_next_backward_async(
                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
            )

    @property
    def need_receive(self) -> bool:
        """Whether a receive shape was given at construction."""
        return self._need_receive

    def start(self) -> None:
        """Advance the generator to its first yield, which launches the communication."""
        next(self._coroutine)

    def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Wait for the communication to complete and return the received tensor(s).

        Must be called after ``start``. The generator is closed afterwards, so this can be called only once.
        """
        received = next(self._coroutine)
        self._coroutine.close()

        return received