[pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
Elsa Granger 2024-01-08 15:37:27 +08:00 committed by GitHub
parent 7bc6969ce6
commit d565df3821
4 changed files with 104 additions and 136 deletions
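The gist of the change: the old fast-send path special-cased bare tensors, lists of tensors, and dicts of tensors (P2PDataType, create_fast_send_metadata, _check_if_fast_send_available); the new one flattens any nested object with torch's pytree utilities, ships tensor leaves as raw tensors, and carries the remaining leaves inside the pickled metadata. A minimal standalone sketch of the flatten/unflatten round trip this relies on (illustrative, not code from the diff):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Any nested structure decomposes into a flat list of leaves plus a TreeSpec
# that records the container layout.
obj = {"hidden_states": torch.randn(2, 4), "step": 3}
leaves, spec = tree_flatten(obj)

# Tensor leaves can travel over NCCL; everything else rides in the metadata.
tensor_idx = [i for i, leaf in enumerate(leaves) if isinstance(leaf, torch.Tensor)]

# The receiver rebuilds the original structure from leaves + spec.
restored = tree_unflatten(leaves, spec)
assert restored["step"] == 3
assert torch.equal(restored["hidden_states"], obj["hidden_states"])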

colossalai/pipeline/p2p.py

@@ -5,20 +5,17 @@ import io
 import pickle
 import re
 from collections import namedtuple
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from .stage_manager import PipelineStageManager
-_unpickler = pickle.Unpickler
 def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     """transform tensor to object with unpickle.
@@ -42,7 +39,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     buf = bytes(buf_array)
     io_bytes = io.BytesIO(buf)
-    byte_pickler = _unpickler(io_bytes)
+    byte_pickler = pickle.Unpickler(io_bytes)
     unpickle = byte_pickler.load()
     return unpickle
@@ -67,7 +64,7 @@ def _broadcast_object_list(
         c10d._warn_not_in_group("broadcast_object_list")
         return
-    is_nccl_backend = check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_backend(group)
     current_device = None
     if device is not None:
@@ -133,45 +130,61 @@ def _broadcast_object_list(
             object_list[i] = unpickle_object
-def check_for_nccl_backend(group):
+def _check_for_nccl_backend(group):
     pg = group or c10d._get_default_group()
     # Gate PG wrapper check on Gloo availability.
     if c10d._GLOO_AVAILABLE:
-        # It is not expected for PG to be wrapped many times, but support it just
-        # in case
+        # It is not expected for PG to be wrapped many times, but support it just in case
         while isinstance(pg, c10d._ProcessGroupWrapper):
             pg = pg.wrapped_pg
     return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
-def check_device(group):
-    is_nccl_backend = check_for_nccl_backend(group)
-    current_device = None
+def _check_device(group):
+    is_nccl_backend = _check_for_nccl_backend(group)
+    current_device = torch.device("cpu")
     if is_nccl_backend:
         current_device = torch.device("cuda", torch.cuda.current_device())
     return current_device, is_nccl_backend
-TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
+TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
+P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])
-class P2PDataType(Enum):
-    Serialization = 0
-    Tensor = 1
-    List = 2
-    Dict = 3
+def create_send_metadata(
+    object: Any, strict: bool = True, return_tensor: bool = False
+) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
+    """
+    Args:
+        object (Any): object needed to be sent
+        strict (bool, optional): whether to check if the object is supported for fast send
+        return_tensor (bool, optional): whether to return tensor objects
+    """
+    objs, tree_spec = tree_flatten(object)
+    tensor_metadata, tensor_objs = [], []
+    non_tensor_obj_idx, non_tensor_objs = [], []
+    for idx, obj in enumerate(objs):
+        if isinstance(obj, torch.Tensor):
+            tensor_objs.append(obj)
+            tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
+        else:
+            non_tensor_obj_idx.append(idx)
+            non_tensor_objs.append(obj)
+    assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"
+    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
+    return metadata if not return_tensor else (metadata, tensor_objs)
-@dataclass
-class P2PMetadata:
-    data_type: P2PDataType
-    content: Union[List[TensorMetadata], TensorMetadata, Any]
-def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
+def _filling_ops_queue(
+    obj: Union[torch.Tensor, List[torch.Tensor]],
+    comm_op: Callable,
+    comm_rank: int,
+    ops_queue: List,
+    group: ProcessGroup,
+):
     if isinstance(obj, torch.Tensor):
         obj = obj.contiguous()
         op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
@@ -179,47 +192,22 @@ def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: Li
     else:
         for tensor_to_comm in obj:
             assert isinstance(tensor_to_comm, torch.Tensor)
-            filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
+            _filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
-def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
-    if p2p_metadata.data_type == P2PDataType.Tensor:
-        metadata = p2p_metadata.content
+def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
+    buffer_recv = []
+    for metadata in tensor_metadata:
         tensor_recv = torch.empty(
             metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
         )
-        return tensor_recv
-    elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
-        buffer_recv = []
-        for metadata in p2p_metadata.content:
-            tensor_recv = torch.empty(
-                metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
-            )
-            buffer_recv.append(tensor_recv)
-        return buffer_recv
-    else:
-        raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
-def create_fast_send_metadata(object: Any) -> P2PMetadata:
-    assert _check_if_fast_send_available(object)
-    if isinstance(object, torch.Tensor):
-        data_type = P2PDataType.Tensor
-        content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
-    elif isinstance(object, list):
-        data_type = P2PDataType.List
-        content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
-    elif isinstance(object, dict):
-        data_type = P2PDataType.Dict
-        content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
-    else:
-        raise RuntimeError("Cannot handle object of type {}".format(type(object)))
-    return P2PMetadata(data_type, content)
+        buffer_recv.append(tensor_recv)
+    return buffer_recv
 def _batch_send_recv_tensor(
-    send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
-    recv_tensor_metadata: Optional[P2PMetadata],
+    send_tensor_list: Optional[List[torch.Tensor]],
+    recv_tensor_metadata: Optional[List[TensorMetadata]],
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup],
@@ -227,16 +215,16 @@ def _batch_send_recv_tensor(
     current_device: Any,
 ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
     buffer_recv = None
-    if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
-        buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
+    if recv_tensor_metadata is not None:
+        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
     ops = []
     if send_dst is not None and send_tensor_list is not None:
         assert send_group is not None
-        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
     if recv_src is not None and buffer_recv is not None:
         assert recv_group is not None
-        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
@@ -247,13 +235,13 @@ def _batch_send_recv_tensor(
     # However, the Megatron-LM does synchronization here
     # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
     # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
     return buffer_recv
 def _send_recv_serialization_object(
-    object: Any,
+    object: Optional[P2PMetadata],
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup],
@@ -274,14 +262,14 @@ def _send_recv_serialization_object(
             send_object_size_tensor = send_object_size_tensor.to(current_device)
             send_object_tensor = send_object_tensor.to(current_device)
-        filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
     recv_object_size_tensor = None
     if recv_src is not None:
         recv_object_size_tensor = torch.empty(1, dtype=torch.long)
         if is_nccl_backend:
             recv_object_size_tensor = recv_object_size_tensor.to(current_device)
-        filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
@@ -289,19 +277,19 @@ def _send_recv_serialization_object(
         req.wait()
     # See the comment in `_batch_send_recv_tensor`
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
     ops = []
     if send_dst is not None and send_object_tensor is not None:
-        filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
     recv_object_tensor = None
     if recv_src is not None and recv_object_size_tensor is not None:
         recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
         if is_nccl_backend:
             recv_object_tensor = recv_object_tensor.to(current_device)
-        filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
@@ -309,7 +297,7 @@ def _send_recv_serialization_object(
         req.wait()
     # See the comment in `_batch_send_recv_tensor`
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
     if recv_object_tensor is not None and recv_object_size_tensor is not None:
         recv_object_tensor = recv_object_tensor.type(torch.uint8)
@@ -324,18 +312,6 @@ def _send_recv_serialization_object(
         return unpickle_object
-def _check_if_fast_send_available(object: Any) -> bool:
-    if isinstance(object, torch.Tensor):
-        return True
-    elif isinstance(object, list):
-        is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
-        return is_list_of_tensor
-    elif isinstance(object, dict):
-        is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
-        return is_dict_of_tensor
-    return False
 def _communicate(
     object: Any,
     send_dst: Optional[int],
@@ -361,10 +337,15 @@ def _communicate(
     assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
     assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
     assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
-    send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
     assert (
-        metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
-    ), "metadata_recv type must not be Serialization"
+        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
+    ), "metadata_recv should not contain non-tensor objects"
+    metadata_send, tensor_objs = None, None
+    if object is not None:
+        # NOTE: if object contains non-tensor objects, we have to send metadata
+        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
+        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
     # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
     # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
@@ -372,9 +353,13 @@ def _communicate(
         assert send_prior_fallback is not None, "Priority must be set if fallback happens"
         if send_prior_fallback:
             _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
-            return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            return _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
         else:
-            recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            recv_data = _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
             _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
             return recv_data
@@ -387,8 +372,8 @@ def _communicate(
     assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
     assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
-    current_send_device, is_send_nccl_backend = check_device(send_group)
-    current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+    current_send_device, is_send_nccl_backend = _check_device(send_group)
+    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
     is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
@@ -396,14 +381,6 @@ def _communicate(
     current_device = current_send_device
     if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
-        metadata_send = None
-        if send_dst is not None and send_metadata:
-            can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
-            if not can_fast_send:
-                metadata_send = P2PMetadata(P2PDataType.Serialization, object)
-            else:
-                metadata_send = create_fast_send_metadata(object)
         # Send and receive metadata
         _metadata_recv = _send_recv_serialization_object(
             object=metadata_send,
@@ -417,31 +394,26 @@ def _communicate(
     assert metadata_recv is None or _metadata_recv is None
     metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
-    send_tensor_list = None
-    if isinstance(object, torch.Tensor):
-        send_tensor_list = object
-    elif isinstance(object, list):
-        send_tensor_list = object
-    elif isinstance(object, dict):
-        send_tensor_list = list(object.values())
     # Send and receive data
-    recv_buffer = _batch_send_recv_tensor(
-        send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
+    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
+    recv_tensor_objs = _batch_send_recv_tensor(
+        tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
     )
     if metadata_recv is not None:
         assert isinstance(metadata_recv, P2PMetadata)
-        if metadata_recv.data_type == P2PDataType.Serialization:
-            return metadata_recv.content
-        else:
-            assert recv_buffer is not None
-            if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
-                return recv_buffer
-            elif metadata_recv.data_type == P2PDataType.Dict:
-                return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
-            else:
-                raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
+        tree_spec = metadata_recv.tree_spec
+        non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
+        non_tensor_objs = metadata_recv.non_tensor_objs
+        if recv_tensor_objs is None:
+            recv_tensor_objs = []
+        for idx in non_tensor_obj_idx:
+            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
+        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
+        return recv_object
 def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
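Taken together, the send path above reduces to: flatten once with create_send_metadata, move the tensor leaves through _batch_send_recv_tensor, and pickle metadata only when it is actually needed. A hedged usage sketch of create_send_metadata as defined in this file (the payload is illustrative):

import torch
from colossalai.pipeline.p2p import create_send_metadata

obj = [torch.zeros(8), {"lr": 0.1}]
metadata, tensor_objs = create_send_metadata(obj, strict=False, return_tensor=True)

# tree_flatten reaches into the inner dict, so the leaves are [tensor, 0.1]:
assert len(tensor_objs) == 1
assert metadata.non_tensor_obj_idx == [1]
assert metadata.non_tensor_objs == [0.1]
# The receive side of _communicate re-inserts 0.1 at index 1 and calls
# tree_unflatten(recv_tensor_objs, metadata.tree_spec) to rebuild the object.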

colossalai/pipeline/schedule/interleaved_pp.py

@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -130,7 +130,7 @@ class InterleavedSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
             return input_tensor
@@ -149,7 +149,7 @@ class InterleavedSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
             return output_tensor_grad
@@ -206,7 +206,7 @@ class InterleavedSchedule(PipelineSchedule):
             )
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
             return output_tensor_grad
         # send only or recv only
@@ -238,7 +238,7 @@ class InterleavedSchedule(PipelineSchedule):
             )
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
             return input_tensor
         # send only or recv only
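Both schedules rely on the same caching idea: pipeline send/recv shapes repeat across microbatches, so metadata derived locally from the first received object (via create_send_metadata) describes every later receive, and passing it as metadata_recv lets subsequent p2p calls skip the pickled-metadata exchange. A condensed sketch of the pattern (attribute names as in the schedule classes):

def _recv_forward_with_cache(self, prev_rank):
    input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
    if self.enable_metadata_cache and self.tensor_metadata_recv is None:
        # Shapes repeat across microbatches, so metadata derived from the
        # first received object is valid for every later receive.
        self.tensor_metadata_recv = create_send_metadata(input_tensor)
    return input_tensor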

colossalai/pipeline/schedule/one_f_one_b.py

@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -121,7 +121,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
             return input_tensor
@@ -138,7 +138,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
             return output_tensor_grad
@@ -188,7 +188,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
             return output_tensor_grad
@@ -214,7 +214,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
             return input_tensor

tests/test_pipeline/test_p2p_communication.py

@@ -4,7 +4,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
@@ -57,19 +57,15 @@ def check_p2p_communication():
             p2p.send_forward(data[-(i + 1)])
             assert recv_obj == data[i]
-    tensor_metadata = TensorMetadata(
-        key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
-    )
-    comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
     if rank == 0:
         recv_obj = p2p.send_forward_recv_backward(
             tensor,
             send_metadata=False,
-            metadata_recv=comm_metadata,
+            metadata_recv=create_send_metadata(tensor),
        )
         assert recv_obj == tensor
     elif rank == 1:
-        recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
+        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
         assert recv_obj == tensor
         p2p.send_backward(tensor, send_metadata=False)
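With the type-specific branches gone, the same send/recv calls should also cover mixed payloads; a hedged extension of this test (payload and asserts are illustrative, not part of the commit):

obj = {"logits": torch.ones(2, 2), "step": 42}
if rank == 0:
    p2p.send_forward(obj)  # tensor leaves batched over NCCL, the rest pickled
elif rank == 1:
    recv_obj = p2p.recv_forward()
    assert recv_obj["step"] == 42
    assert torch.equal(recv_obj["logits"], torch.ones(2, 2))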