mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] A more general _communicate in p2p (#5062)
* A more general _communicate
* feat: finish tree_flatten version p2p
* fix: update p2p api calls
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
pull/5247/head
parent
7bc6969ce6
commit
d565df3821
|
@ -5,20 +5,17 @@ import io
|
|||
import pickle
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging.version import Version
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed import distributed_c10d as c10d
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from .stage_manager import PipelineStageManager
|
||||
|
||||
_unpickler = pickle.Unpickler
|
||||
|
||||
|
||||
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
|
||||
"""transform tensor to object with unpickle.
|
||||
|
@ -42,7 +39,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
|||
buf = bytes(buf_array)
|
||||
|
||||
io_bytes = io.BytesIO(buf)
|
||||
byte_pickler = _unpickler(io_bytes)
|
||||
byte_pickler = pickle.Unpickler(io_bytes)
|
||||
unpickle = byte_pickler.load()
|
||||
|
||||
return unpickle
|
||||
|
@ -67,7 +64,7 @@ def _broadcast_object_list(
|
|||
c10d._warn_not_in_group("broadcast_object_list")
|
||||
return
|
||||
|
||||
is_nccl_backend = check_for_nccl_backend(group)
|
||||
is_nccl_backend = _check_for_nccl_backend(group)
|
||||
current_device = None
|
||||
|
||||
if device is not None:
|
||||
|
@ -133,45 +130,61 @@ def _broadcast_object_list(
|
|||
object_list[i] = unpickle_object
|
||||
|
||||
|
||||
def check_for_nccl_backend(group):
|
||||
def _check_for_nccl_backend(group):
|
||||
pg = group or c10d._get_default_group()
|
||||
# Gate PG wrapper check on Gloo availability.
|
||||
if c10d._GLOO_AVAILABLE:
|
||||
# It is not expected for PG to be wrapped many times, but support it just
|
||||
# in case
|
||||
# It is not expected for PG to be wrapped many times, but support it just in case
|
||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||
pg = pg.wrapped_pg
|
||||
|
||||
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
|
||||
|
||||
|
||||
def check_device(group):
|
||||
is_nccl_backend = check_for_nccl_backend(group)
|
||||
current_device = None
|
||||
|
||||
def _check_device(group):
|
||||
is_nccl_backend = _check_for_nccl_backend(group)
|
||||
current_device = torch.device("cpu")
|
||||
if is_nccl_backend:
|
||||
current_device = torch.device("cuda", torch.cuda.current_device())
|
||||
return current_device, is_nccl_backend
|
||||
|
||||
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
|
||||
P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])
|
||||
|
||||
|
||||
class P2PDataType(Enum):
|
||||
Serialization = 0
|
||||
Tensor = 1
|
||||
List = 2
|
||||
Dict = 3
|
||||
def create_send_metadata(
    object: Any, strict: bool = True, return_tensor: bool = False
) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
    """Flatten ``object`` and build the metadata describing its leaves.

    The object is flattened with ``tree_flatten``. Each tensor leaf is recorded
    as ``TensorMetadata(shape, dtype, requires_grad)``; every other leaf is kept
    verbatim together with its index in the flattened leaf list.

    Args:
        object (Any): object needed to be sent
        strict (bool, optional): whether to check if the object is supported for fast send,
            i.e. assert that it contains no non-tensor leaves
        return_tensor (bool, optional): whether to return tensor objects

    Returns:
        The ``P2PMetadata`` alone, or a ``(P2PMetadata, tensor_list)`` pair when
        ``return_tensor`` is true.

    Raises:
        AssertionError: if ``strict`` is true and a non-tensor leaf is present.
    """
    leaves, tree_spec = tree_flatten(object)

    # Tensor leaves, in flattened order, and their per-tensor descriptions.
    tensor_objs = [leaf for leaf in leaves if isinstance(leaf, torch.Tensor)]
    tensor_metadata = [TensorMetadata(t.shape, t.dtype, t.requires_grad) for t in tensor_objs]

    # Non-tensor leaves are remembered with their position in the flattened list.
    others = [(pos, leaf) for pos, leaf in enumerate(leaves) if not isinstance(leaf, torch.Tensor)]
    non_tensor_obj_idx = [pos for pos, _ in others]
    non_tensor_objs = [leaf for _, leaf in others]

    if strict:
        assert len(non_tensor_objs) == 0, "Only support tensor for fast send"

    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
    if return_tensor:
        return metadata, tensor_objs
    return metadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class P2PMetadata:
|
||||
data_type: P2PDataType
|
||||
content: Union[List[TensorMetadata], TensorMetadata, Any]
|
||||
|
||||
|
||||
def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
|
||||
def _filling_ops_queue(
|
||||
obj: Union[torch.Tensor, List[torch.Tensor]],
|
||||
comm_op: Callable,
|
||||
comm_rank: int,
|
||||
ops_queue: List,
|
||||
group: ProcessGroup,
|
||||
):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
obj = obj.contiguous()
|
||||
op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
|
||||
|
@ -179,47 +192,22 @@ def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: Li
|
|||
else:
|
||||
for tensor_to_comm in obj:
|
||||
assert isinstance(tensor_to_comm, torch.Tensor)
|
||||
filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
|
||||
_filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
|
||||
|
||||
|
||||
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
|
||||
if p2p_metadata.data_type == P2PDataType.Tensor:
|
||||
metadata = p2p_metadata.content
|
||||
def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
|
||||
buffer_recv = []
|
||||
for metadata in tensor_metadata:
|
||||
tensor_recv = torch.empty(
|
||||
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
|
||||
)
|
||||
return tensor_recv
|
||||
elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
|
||||
buffer_recv = []
|
||||
for metadata in p2p_metadata.content:
|
||||
tensor_recv = torch.empty(
|
||||
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
|
||||
)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv
|
||||
else:
|
||||
raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
|
||||
|
||||
|
||||
def create_fast_send_metadata(object: Any) -> P2PMetadata:
|
||||
assert _check_if_fast_send_available(object)
|
||||
if isinstance(object, torch.Tensor):
|
||||
data_type = P2PDataType.Tensor
|
||||
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
|
||||
elif isinstance(object, list):
|
||||
data_type = P2PDataType.List
|
||||
content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
|
||||
elif isinstance(object, dict):
|
||||
data_type = P2PDataType.Dict
|
||||
content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
|
||||
else:
|
||||
raise RuntimeError("Cannot handle object of type {}".format(type(object)))
|
||||
return P2PMetadata(data_type, content)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv
|
||||
|
||||
|
||||
def _batch_send_recv_tensor(
|
||||
send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||
recv_tensor_metadata: Optional[P2PMetadata],
|
||||
send_tensor_list: Optional[List[torch.Tensor]],
|
||||
recv_tensor_metadata: Optional[List[TensorMetadata]],
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
send_group: Optional[ProcessGroup],
|
||||
|
@ -227,16 +215,16 @@ def _batch_send_recv_tensor(
|
|||
current_device: Any,
|
||||
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
buffer_recv = None
|
||||
if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
|
||||
buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
|
||||
if recv_tensor_metadata is not None:
|
||||
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
|
||||
|
||||
ops = []
|
||||
if send_dst is not None and send_tensor_list is not None:
|
||||
assert send_group is not None
|
||||
filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
if recv_src is not None and buffer_recv is not None:
|
||||
assert recv_group is not None
|
||||
filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
|
@ -247,13 +235,13 @@ def _batch_send_recv_tensor(
|
|||
# However, the Megatron-LM does synchronization here
|
||||
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
|
||||
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
|
||||
torch.cuda.synchronize()
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
return buffer_recv
|
||||
|
||||
|
||||
def _send_recv_serialization_object(
|
||||
object: Any,
|
||||
object: Optional[P2PMetadata],
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
send_group: Optional[ProcessGroup],
|
||||
|
@ -274,14 +262,14 @@ def _send_recv_serialization_object(
|
|||
send_object_size_tensor = send_object_size_tensor.to(current_device)
|
||||
send_object_tensor = send_object_tensor.to(current_device)
|
||||
|
||||
filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
recv_object_size_tensor = None
|
||||
if recv_src is not None:
|
||||
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
|
||||
if is_nccl_backend:
|
||||
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
|
||||
filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
|
@ -289,19 +277,19 @@ def _send_recv_serialization_object(
|
|||
req.wait()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
torch.cuda.synchronize()
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
ops = []
|
||||
|
||||
if send_dst is not None and send_object_tensor is not None:
|
||||
filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
recv_object_tensor = None
|
||||
if recv_src is not None and recv_object_size_tensor is not None:
|
||||
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
|
||||
if is_nccl_backend:
|
||||
recv_object_tensor = recv_object_tensor.to(current_device)
|
||||
filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
|
@ -309,7 +297,7 @@ def _send_recv_serialization_object(
|
|||
req.wait()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
torch.cuda.synchronize()
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
if recv_object_tensor is not None and recv_object_size_tensor is not None:
|
||||
recv_object_tensor = recv_object_tensor.type(torch.uint8)
|
||||
|
@ -324,18 +312,6 @@ def _send_recv_serialization_object(
|
|||
return unpickle_object
|
||||
|
||||
|
||||
def _check_if_fast_send_available(object: Any) -> bool:
|
||||
if isinstance(object, torch.Tensor):
|
||||
return True
|
||||
elif isinstance(object, list):
|
||||
is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
|
||||
return is_list_of_tensor
|
||||
elif isinstance(object, dict):
|
||||
is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
|
||||
return is_dict_of_tensor
|
||||
return False
|
||||
|
||||
|
||||
def _communicate(
|
||||
object: Any,
|
||||
send_dst: Optional[int],
|
||||
|
@ -361,10 +337,15 @@ def _communicate(
|
|||
assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
|
||||
assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
|
||||
assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
|
||||
send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
|
||||
assert (
|
||||
metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
|
||||
), "metadata_recv type must not be Serialization"
|
||||
metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
|
||||
), "metadata_recv should not contain non-tensor objects"
|
||||
|
||||
metadata_send, tensor_objs = None, None
|
||||
if object is not None:
|
||||
# NOTE: if object contains non-tensor objects, we have to send metadata
|
||||
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
|
||||
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
|
||||
|
||||
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||
|
@ -372,9 +353,13 @@ def _communicate(
|
|||
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
|
||||
if send_prior_fallback:
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
return _communicate(
|
||||
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
|
||||
)
|
||||
else:
|
||||
recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
recv_data = _communicate(
|
||||
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
|
||||
)
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return recv_data
|
||||
|
||||
|
@ -387,8 +372,8 @@ def _communicate(
|
|||
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
|
||||
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
|
||||
|
||||
current_send_device, is_send_nccl_backend = check_device(send_group)
|
||||
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
|
||||
current_send_device, is_send_nccl_backend = _check_device(send_group)
|
||||
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
|
||||
|
||||
is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
|
||||
|
||||
|
@ -396,14 +381,6 @@ def _communicate(
|
|||
current_device = current_send_device
|
||||
|
||||
if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
|
||||
metadata_send = None
|
||||
if send_dst is not None and send_metadata:
|
||||
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
|
||||
if not can_fast_send:
|
||||
metadata_send = P2PMetadata(P2PDataType.Serialization, object)
|
||||
else:
|
||||
metadata_send = create_fast_send_metadata(object)
|
||||
|
||||
# Send and receive metadata
|
||||
_metadata_recv = _send_recv_serialization_object(
|
||||
object=metadata_send,
|
||||
|
@ -417,31 +394,26 @@ def _communicate(
|
|||
assert metadata_recv is None or _metadata_recv is None
|
||||
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
|
||||
|
||||
send_tensor_list = None
|
||||
if isinstance(object, torch.Tensor):
|
||||
send_tensor_list = object
|
||||
elif isinstance(object, list):
|
||||
send_tensor_list = object
|
||||
elif isinstance(object, dict):
|
||||
send_tensor_list = list(object.values())
|
||||
|
||||
# Send and receive data
|
||||
recv_buffer = _batch_send_recv_tensor(
|
||||
send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
|
||||
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
|
||||
recv_tensor_objs = _batch_send_recv_tensor(
|
||||
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
|
||||
)
|
||||
|
||||
if metadata_recv is not None:
|
||||
assert isinstance(metadata_recv, P2PMetadata)
|
||||
if metadata_recv.data_type == P2PDataType.Serialization:
|
||||
return metadata_recv.content
|
||||
else:
|
||||
assert recv_buffer is not None
|
||||
if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
|
||||
return recv_buffer
|
||||
elif metadata_recv.data_type == P2PDataType.Dict:
|
||||
return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
|
||||
else:
|
||||
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
|
||||
tree_spec = metadata_recv.tree_spec
|
||||
non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
|
||||
non_tensor_objs = metadata_recv.non_tensor_objs
|
||||
|
||||
if recv_tensor_objs is None:
|
||||
recv_tensor_objs = []
|
||||
|
||||
for idx in non_tensor_obj_idx:
|
||||
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
|
||||
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
|
||||
|
||||
return recv_object
|
||||
|
||||
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
|
@ -130,7 +130,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
@ -149,7 +149,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
@ -206,7 +206,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad
|
||||
|
||||
# send only or recv only
|
||||
|
@ -238,7 +238,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
# send only or recv only
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
|
@ -121,7 +121,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
@ -138,7 +138,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
@ -188,7 +188,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
@ -214,7 +214,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.distributed as dist
|
|||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -57,19 +57,15 @@ def check_p2p_communication():
|
|||
p2p.send_forward(data[-(i + 1)])
|
||||
assert recv_obj == data[i]
|
||||
|
||||
tensor_metadata = TensorMetadata(
|
||||
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
|
||||
)
|
||||
comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
|
||||
if rank == 0:
|
||||
recv_obj = p2p.send_forward_recv_backward(
|
||||
tensor,
|
||||
send_metadata=False,
|
||||
metadata_recv=comm_metadata,
|
||||
metadata_recv=create_send_metadata(tensor),
|
||||
)
|
||||
assert recv_obj == tensor
|
||||
elif rank == 1:
|
||||
recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
|
||||
recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
|
||||
assert recv_obj == tensor
|
||||
p2p.send_backward(tensor, send_metadata=False)
|
||||
|
||||
|
|
Loading…
Reference in New Issue