mirror of https://github.com/hpcaitech/ColossalAI
225 lines
7.9 KiB
Python
225 lines
7.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import io
|
|
import pickle
|
|
from typing import Any, List, Optional, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
from torch.distributed import distributed_c10d as c10d
|
|
|
|
from .stage_manager import PipelineStageManager
|
|
|
|
_unpickler = pickle.Unpickler
|
|
|
|
|
|
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
|
|
"""transform tensor to object with unpickle.
|
|
Info of the device in bytes stream will be modified into current device before unpickling
|
|
|
|
Args:
|
|
tensor (:class:`torch.tensor`): tensor to be unpickled
|
|
tensor_size (:class:`torch.Size`): Size of the real info in bytes
|
|
|
|
Returns:
|
|
Any: object after unpickled
|
|
"""
|
|
buf = tensor.numpy().tobytes()[:tensor_size]
|
|
if b'cuda' in buf:
|
|
buf_array = bytearray(buf)
|
|
device_index = torch.cuda.current_device()
|
|
buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
|
|
buf = bytes(buf_array)
|
|
|
|
io_bytes = io.BytesIO(buf)
|
|
byte_pickler = _unpickler(io_bytes)
|
|
unpickle = byte_pickler.load()
|
|
|
|
return unpickle
|
|
|
|
|
|
def _broadcast_object_list(object_list: List[Any],
|
|
src: int,
|
|
group: ProcessGroup,
|
|
device: Optional[Union[torch.device, str, int]] = None):
|
|
"""This is a modified version of the broadcast_object_list in torch.distribution
|
|
The only difference is that object will be move to correct device after unpickled.
|
|
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
|
|
be updated with data sent from rank src.
|
|
|
|
Args:
|
|
object_list (List[Any]): list of object to broadcast
|
|
src (int): source rank to broadcast
|
|
dst (int): dst rank to broadcast
|
|
device (:class:`torch.device`): device to do broadcast. current device in default
|
|
|
|
"""
|
|
|
|
if c10d._rank_not_in_group(group):
|
|
c10d._warn_not_in_group("broadcast_object_list")
|
|
return
|
|
|
|
my_rank = dist.get_rank()
|
|
# Serialize object_list elements to tensors on src rank.
|
|
if my_rank == src:
|
|
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
|
|
object_sizes_tensor = torch.cat(size_list)
|
|
else:
|
|
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
|
|
|
|
is_nccl_backend = c10d._check_for_nccl_backend(group)
|
|
current_device = None
|
|
|
|
if device is not None:
|
|
if is_nccl_backend and device.type != "cuda":
|
|
raise ValueError("device type must be cuda for nccl backend")
|
|
current_device = device
|
|
else:
|
|
current_device = torch.device("cpu")
|
|
if is_nccl_backend:
|
|
current_device = torch.device("cuda", torch.cuda.current_device())
|
|
if is_nccl_backend:
|
|
object_sizes_tensor = object_sizes_tensor.to(current_device)
|
|
|
|
# Broadcast object sizes
|
|
c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)
|
|
|
|
# Concatenate and broadcast serialized object tensors
|
|
if my_rank == src:
|
|
object_tensor = torch.cat(tensor_list)
|
|
else:
|
|
object_tensor = torch.empty( # type: ignore[call-overload]
|
|
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
|
|
dtype=torch.uint8,
|
|
)
|
|
|
|
if is_nccl_backend:
|
|
object_tensor = object_tensor.to(current_device)
|
|
|
|
c10d.broadcast(object_tensor, src=src, group=group, async_op=False)
|
|
|
|
# Deserialize objects using their stored sizes.
|
|
offset = 0
|
|
|
|
if my_rank != src:
|
|
for i, obj_size in enumerate(object_sizes_tensor):
|
|
obj_view = object_tensor[offset:offset + obj_size]
|
|
obj_view = obj_view.type(torch.uint8)
|
|
if obj_view.device != torch.device("cpu"):
|
|
obj_view = obj_view.cpu()
|
|
offset += obj_size
|
|
# unpickle
|
|
unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)
|
|
|
|
# unconsistence in device
|
|
if isinstance(unpickle_object,
|
|
torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
|
|
unpickle_object = unpickle_object.cuda()
|
|
|
|
object_list[i] = unpickle_object
|
|
|
|
|
|
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
|
|
"""send anything to dst rank
|
|
|
|
Args:
|
|
object (Any): object needed to be sent
|
|
dst (int): rank of the destination
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
# then broadcast safely
|
|
_broadcast_object_list([object], src, group)
|
|
|
|
|
|
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
|
|
"""recv anything from src
|
|
|
|
Args:
|
|
src (int): source rank of data. local rank will receive data from src rank.
|
|
|
|
Returns:
|
|
Any: Object received from src.
|
|
"""
|
|
object_list = [None]
|
|
_broadcast_object_list(object_list, src, group)
|
|
|
|
return object_list[0]
|
|
|
|
|
|
class PipelineP2PCommunication:
|
|
|
|
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
|
self.stage_manager = stage_manager
|
|
|
|
def recv_forward(self, prev_rank: int = None) -> Any:
|
|
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
|
|
|
Args:
|
|
prev_rank (int, optional): The rank of the source of the tensor.
|
|
|
|
Returns:
|
|
Any: The input tensor or input tensor list.
|
|
"""
|
|
if self.stage_manager.is_first_stage():
|
|
input_tensor = None
|
|
else:
|
|
if prev_rank is None:
|
|
prev_rank = self.stage_manager.get_prev_rank()
|
|
cur_rank = self.stage_manager.get_rank()
|
|
input_tensor = _recv_object(prev_rank, cur_rank,
|
|
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
|
|
|
|
return input_tensor
|
|
|
|
def recv_backward(self, next_rank: int = None) -> Any:
|
|
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
|
|
|
Args:
|
|
next_rank (int, optional): The rank of the source of the tensor.
|
|
|
|
Returns:
|
|
Any: The input gradient tensor or gradient tensor list.
|
|
"""
|
|
if self.stage_manager.is_last_stage():
|
|
output_tensor_grad = None
|
|
else:
|
|
if next_rank is None:
|
|
next_rank = self.stage_manager.get_next_rank()
|
|
cur_rank = self.stage_manager.get_rank()
|
|
output_tensor_grad = _recv_object(next_rank, cur_rank,
|
|
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
|
|
|
|
return output_tensor_grad
|
|
|
|
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
|
"""Sends the input tensor to the next stage in pipeline.
|
|
|
|
Args:
|
|
output_object (Any): Object to be sent.
|
|
next_rank (int, optional): The rank of the recipient of the tensor.
|
|
"""
|
|
if not self.stage_manager.is_last_stage():
|
|
if next_rank is None:
|
|
next_rank = self.stage_manager.get_next_rank()
|
|
cur_rank = self.stage_manager.get_rank()
|
|
_send_object(output_object, cur_rank, next_rank,
|
|
self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
|
|
|
|
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
|
"""Sends the gradient tensor to the previous stage in pipeline.
|
|
|
|
Args:
|
|
input_object (Any): Object to be sent.
|
|
prev_rank (int, optional): The rank of the recipient of the tensor
|
|
"""
|
|
if not self.stage_manager.is_first_stage():
|
|
if prev_rank is None:
|
|
prev_rank = self.stage_manager.get_prev_rank()
|
|
cur_rank = self.stage_manager.get_rank()
|
|
_send_object(input_object, cur_rank, prev_rank,
|
|
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
|