From 2a25a2aff71439a242aff0522f65da6df2805b2a Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 26 Jun 2024 14:48:02 +0800 Subject: [PATCH] [Feature] optimize PP overlap (#5735) * update to fully overlap, still debugging * improve interface * fixed deadlock bug * debug NaN loss * (experimental) use one comm group for send_fw_recv_fw to fix NaN * cleaned up interfaces; use one batch p2p for all * clean up; removed the double p2p batch case * p2p test passed * improve overlap: send fwd before backward * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tentatively use 2 p2p batches * remove two p2p batches * fix typos * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove pp.sh --------- Co-authored-by: Edenzzzz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: root --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +- colossalai/cluster/process_group_mesh.py | 4 +- colossalai/pipeline/p2p.py | 361 +++++++++++------- .../pipeline/schedule/interleaved_pp.py | 300 ++++++++------- colossalai/pipeline/schedule/one_f_one_b.py | 35 +- colossalai/pipeline/stage_manager.py | 39 +- examples/language/llama/benchmark.py | 44 ++- tests/test_pipeline/test_p2p_communication.py | 20 +- tests/test_pipeline/test_stage_manager.py | 2 +- 9 files changed, 456 insertions(+), 357 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fa3c3646a..3bd43f172 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64. 
- + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True. """ def __init__( @@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, dp_outside: bool = True, + overlap_p2p: bool = True, ) -> None: super().__init__() assert ( @@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase): assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" - assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, @@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase): num_microbatch=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index fea4a23ba..f0cb78c5f 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -134,7 +134,7 @@ class ProcessGroupMesh: """ assert mode in ["raise", "wrap", "clip"] - return np.ravel_multi_index(coord, shape, mode) + return int(np.ravel_multi_index(coord, shape, mode)) def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. If the process group doesn't exist, it will be created. @@ -182,7 +182,7 @@ class ProcessGroupMesh: axis = [ axis, ] - assert isinstance(indices_at_axis[0], int) + assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}." 
indices_at_axis = [ indices_at_axis, ] diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 1b55b140c..ed190eb08 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -225,31 +225,41 @@ def _batch_send_recv_tensor( send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup], current_device: Any, + overlap_p2p: bool = True, + send_first: bool = True, ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]: buffer_recv = None if recv_tensor_metadata is not None: buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device) ops = [] - if send_dst is not None and send_tensor_list is not None: - assert send_group is not None - _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) - if recv_src is not None and buffer_recv is not None: - assert recv_group is not None - _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + is_send = send_dst is not None and send_tensor_list is not None + is_recv = recv_src is not None and buffer_recv is not None + + if send_first: + if is_send: + assert send_group is not None + _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) + if is_recv: + assert recv_group is not None + _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + else: + if is_recv: + assert recv_group is not None + _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + if is_send: + assert send_group is not None + _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - # Remove synchronization according to Pytorch's documentation - # However, the Megatron-LM does synchronization here - # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112 - # In case there is potential error, uncomment the following `torch.cuda.synchronize()` - # torch.cuda.synchronize() - - return buffer_recv + if not overlap_p2p: + for req in reqs: + req.wait() + return buffer_recv, [] + else: + return buffer_recv, reqs + return None, [] def _send_recv_serialization_object( @@ -260,10 +270,11 @@ def _send_recv_serialization_object( recv_group: Optional[ProcessGroup], current_device: Any, is_nccl_backend: bool, + send_first: bool = True, ) -> Optional[P2PMetadata]: ops = [] - send_object_tensor = None + send_object_size_tensor = None if object is not None and send_dst is not None: if Version(torch.__version__) >= Version("1.13.0"): send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) @@ -274,43 +285,54 @@ def _send_recv_serialization_object( send_object_size_tensor = send_object_size_tensor.to(current_device) send_object_tensor = send_object_tensor.to(current_device) - _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) - recv_object_size_tensor = None if recv_src is not None: recv_object_size_tensor = torch.empty(1, dtype=torch.long) if is_nccl_backend: recv_object_size_tensor = recv_object_size_tensor.to(current_device) - _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + + if send_first: + if send_object_size_tensor is not None: + _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) + if recv_src is not None: + _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + else: + if recv_src is not None: + 
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + if send_object_size_tensor is not None: + _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: - req.wait() - - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() + req.wait() # This blocks the compute stream in torch ops = [] - - if send_dst is not None and send_object_tensor is not None: - _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) + is_send = send_dst is not None and send_object_tensor is not None + is_recv = recv_src is not None and recv_object_size_tensor is not None recv_object_tensor = None - if recv_src is not None and recv_object_size_tensor is not None: + if is_recv: recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8) if is_nccl_backend: recv_object_tensor = recv_object_tensor.to(current_device) - _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + + if send_first: + if is_send: + _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) + if is_recv: + _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + else: + if is_recv: + _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + if is_send: + _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: req.wait() - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() - if recv_object_tensor is not None and recv_object_size_tensor is not None: recv_object_tensor = recv_object_tensor.type(torch.uint8) if recv_object_tensor.device != torch.device("cpu"): @@ -328,11 +350,12 @@ def _communicate( object: Any, send_dst: Optional[int], recv_src: Optional[int], + overlap_p2p: bool, send_group: Optional[ProcessGroup] = None, recv_group: Optional[ProcessGroup] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, - send_prior_fallback: Optional[bool] = None, + send_first: Optional[bool] = None, ) -> Any: """ Send and receive object from send_dst and recv_src respectively @@ -341,6 +364,7 @@ def _communicate( object (Any): object needed to be sent send_dst (int): rank of the destination recv_src (int): rank of the source + overlap_p2p (bool): whether to overlap p2p communication with computation send_group (ProcessGroup, optional): process group of sender recv_group (ProcessGroup, optional): process group of receiver send_metadata (bool, optional): whether to send metadata @@ -358,32 +382,10 @@ def _communicate( # NOTE: if object contains non-tensor objects, we have to send metadata metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True) send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0 + else: + send_metadata = False - # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata, - # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. 
- if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): - assert send_prior_fallback is not None, "Priority must be set if fallback happens" - if send_prior_fallback: - _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) - return _communicate( - None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv - ) - else: - recv_data = _communicate( - None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv - ) - _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) - return recv_data - - # NOTE: only the following 5 cases are valid: - # 1. send() [needs extra metadata] and no recv() - # 2. recv() [needs extra metadata] and no send() - # 3. neither send() nor recv() need extra metadata - assert not (send_dst is not None and send_metadata) or recv_src is None - assert not (recv_src is not None and metadata_recv is None) or send_dst is None - assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None) assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group) - current_send_device, is_send_nccl_backend = _check_device(send_group) current_recv_device, is_recv_nccl_backend = _check_device(recv_group) @@ -402,14 +404,25 @@ def _communicate( recv_group=recv_group if metadata_recv is None else None, current_device=current_device, is_nccl_backend=is_nccl_backend, + send_first=send_first if send_first != None else True, ) - assert metadata_recv is None or _metadata_recv is None + assert ( + metadata_recv is None or _metadata_recv is None + ), "You shouldn't receive metadata when using the cached metadata" metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv # Send and receive data recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata - recv_tensor_objs = _batch_send_recv_tensor( - tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device + recv_tensor_objs, wait_handles = _batch_send_recv_tensor( + tensor_objs, + recv_tensor_metadata, + send_dst, + recv_src, + send_group, + recv_group, + current_device, + overlap_p2p=overlap_p2p, + send_first=send_first if send_first != None else True, ) if metadata_recv is not None: @@ -424,33 +437,9 @@ def _communicate( for idx in non_tensor_obj_idx: recv_tensor_objs.insert(idx, non_tensor_objs.pop(0)) recv_object = tree_unflatten(recv_tensor_objs, tree_spec) + return recv_object, wait_handles - return recv_object - - -def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None: - """send anything to dst rank - - Args: - object (Any): object needed to be sent - dst (int): rank of the destination - - Returns: - None - """ - _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs) - - -def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any: - """recv anything from src - - Args: - src (int): source rank of data. local rank will receive data from src rank. - - Returns: - Any: Object received from src. 
- """ - return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs) + return None, wait_handles def _p2p_comm( @@ -532,10 +521,13 @@ def _p2p_comm( class PipelineP2PCommunication: - def __init__(self, stage_manager: PipelineStageManager) -> None: + def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None: self.stage_manager = stage_manager + self.overlap_p2p = overlap_p2p - def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: + def recv_forward( + self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None + ) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -543,95 +535,186 @@ class PipelineP2PCommunication: Returns: Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. """ if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object( - prev_rank, - cur_rank, - self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), + input_tensor, wait_handles = _communicate( + object=None, + recv_src=prev_rank, + send_dst=None, + recv_group=self.stage_manager.get_p2p_process_group(), metadata_recv=metadata_recv, + overlap_p2p=self.overlap_p2p, ) - return input_tensor + return input_tensor, wait_handles - def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: + def recv_backward( + self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None + ) -> Tuple[Any, List]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. - Args: next_rank (int, optional): The rank of the source of the tensor. Returns: - Any: The input gradient tensor or gradient tensor list. + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. """ if next_rank is None: next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object( - next_rank, - cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank), + + output_tensor_grad, wait_handles = _communicate( + object=None, + recv_src=next_rank, + send_dst=None, + recv_group=self.stage_manager.get_p2p_process_group(), metadata_recv=metadata_recv, + overlap_p2p=self.overlap_p2p, ) - return output_tensor_grad + return output_tensor_grad, wait_handles - def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None: + def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List: """Sends the input tensor to the next stage in pipeline. Args: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + List: List of handles for the communication requests, if overlap is enabled. 
""" if next_rank is None: next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - _send_object( + _, handles = _communicate( output_object, - cur_rank, - next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + recv_src=None, + send_dst=next_rank, + send_group=self.stage_manager.get_p2p_process_group(), send_metadata=send_metadata, + overlap_p2p=self.overlap_p2p, ) + return handles - def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None: + def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List: """Sends the gradient tensor to the previous stage in pipeline. Args: input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + List: List of handles for the communication requests, if overlap is enabled. """ if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - _send_object( + _, handles = _communicate( input_object, - cur_rank, - prev_rank, - self.stage_manager.get_p2p_process_group(cur_rank, prev_rank), + recv_src=None, + send_dst=prev_rank, + send_group=self.stage_manager.get_p2p_process_group(), send_metadata=send_metadata, + overlap_p2p=self.overlap_p2p, ) + return handles - def send_forward_recv_backward( + def send_forward_recv_forward( + self, + output_object: Any, + is_send: bool, + is_recv: bool, + send_first: bool, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Tuple[Any, List]: + """Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage + + Args: + output_object (Any): Object to be sent. + is_send (bool): Whether to send the input tensor to the next pipeline stage. + is_recv (bool): Whether to copy the output tensor from the next pipeline stage. + send_first (bool): Whether to send before receive. + send_metadata (bool, optional): Whether to send metadata. + metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received. + + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. + """ + next_rank = self.stage_manager.get_next_rank() if is_send else None + prev_rank = self.stage_manager.get_prev_rank() if is_recv else None + group = self.stage_manager.get_p2p_process_group() + return _communicate( + output_object, + send_dst=next_rank, + recv_src=prev_rank, + send_group=group if is_send else None, + recv_group=group if is_recv else None, + send_metadata=send_metadata if is_send else False, + metadata_recv=metadata_recv if is_recv else None, + send_first=send_first, + overlap_p2p=self.overlap_p2p, + ) + + def send_backward_recv_backward( self, input_object: Any, - next_rank: Optional[int] = None, + is_send: bool, + is_recv: bool, + send_first: bool, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: - """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline + ) -> Tuple[Any, List]: + """Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage Args: input_object (Any): Object to be sent. 
- next_rank (int, optional): The rank of the sender and recipient of the tensor + is_send (bool): Whether to send the gradient tensor to the previous pipeline stage. + is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage. + send_first (bool): Whether to send before receive. + send_metadata (bool, optional): Whether to send metadata. + metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received. + + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. """ - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() + prev_rank = self.stage_manager.get_prev_rank() if is_send else None + next_rank = self.stage_manager.get_next_rank() if is_recv else None + + group = self.stage_manager.get_p2p_process_group() - cur_rank = self.stage_manager.get_rank() - group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) + return _communicate( + input_object, + send_dst=prev_rank, + recv_src=next_rank, + send_group=group if is_send else None, + recv_group=group if is_recv else None, + send_metadata=send_metadata if is_send else False, + metadata_recv=metadata_recv if is_recv else None, + send_first=send_first, + overlap_p2p=self.overlap_p2p, + ) + + def send_forward_recv_backward( + self, + input_object: Any, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + send_first: Optional[bool] = None, + ) -> Tuple[Any, List]: + """Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage + + Args: + input_object (Any): Object to be sent. + + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. + """ + next_rank = self.stage_manager.get_next_rank() + group = self.stage_manager.get_p2p_process_group() return _communicate( input_object, next_rank, @@ -640,28 +723,28 @@ class PipelineP2PCommunication: recv_group=group, send_metadata=send_metadata, metadata_recv=metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, + overlap_p2p=False, ) def send_backward_recv_forward( self, input_object: Any, - prev_rank: Optional[int] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: + send_first: Optional[bool] = None, + ) -> Tuple[Any, List]: """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline Args: input_object (Any): Object to be sent. - prev_rank (int, optional): The rank of the sender and recipient of the tensor - """ - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. 
+ """ + prev_rank = self.stage_manager.get_prev_rank() + group = self.stage_manager.get_p2p_process_group() return _communicate( input_object, prev_rank, @@ -670,7 +753,8 @@ class PipelineP2PCommunication: recv_group=group, send_metadata=send_metadata, metadata_recv=metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, + overlap_p2p=False, ) def p2p_communicate( @@ -679,7 +763,7 @@ class PipelineP2PCommunication: recv_pre: bool, next_rank: Optional[int] = None, comm_dtype: torch.dtype = torch.float16, - ) -> None: + ) -> Any: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -689,12 +773,11 @@ class PipelineP2PCommunication: """ if next_rank is None: next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() recv_tensor = _p2p_comm( output_object, recv_pre, next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + self.stage_manager.get_p2p_process_group(), comm_dtype, ) return recv_tensor diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a4ace5e1b..a21b45c44 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -1,8 +1,9 @@ from functools import partial -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.cuda +import torch.distributed from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map @@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_ from .base import PipelineSchedule +def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: + if wait_handles is not None: + for req in wait_handles: + req.wait() + + class InterleavedSchedule(PipelineSchedule): def __init__( self, @@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule): num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + overlap_p2p: bool = True, ) -> None: super().__init__(stage_manager) assert ( num_microbatch is not None or microbatch_size is not None ), "Either num_microbatch or microbatch_size should be provided" - self.comm = PipelineP2PCommunication(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) + self.overlap_p2p = overlap_p2p self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size self.num_model_chunks = num_model_chunks @@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule): Returns: int: The model chunk idx of the input microbatch_id """ - assert microbatch_id < self.num_microbatch * self.num_model_chunks + assert ( + microbatch_id < self.num_microbatch * self.num_model_chunks + ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not is_forward: + # Reverse order model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id - def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. 
For interleaved 1F1B. @@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input tensor or input tensor list. + Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) - return input_tensor + return input_tensor, wait_handles + return None, [] - def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. For interleaved 1F1B. @@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input gradient tensor or gradient tensor list. + Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + output_tensor_grad, wait_handles = self.comm.recv_backward( + next_rank, metadata_recv=self.grad_metadata_recv + ) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, wait_handles - return output_tensor_grad + return None, [] - def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None: + def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List: """Sends the input tensor to the next stage in pipeline. For interleaved 1F1B. @@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule): model_chunk_id (int): The current model chunk idx. output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) + send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + return send_handles + return [] - def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None: + def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List: """Sends the gradient tensor to the previous stage in pipeline. For interleaved 1F1B. @@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule): model_chunk_id (int): The current model chunk idx. input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + Any: The wait handles for the communication. 
""" with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) + send_handles = self.comm.send_backward( + input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata + ) self.send_grad_metadata = not self.enable_metadata_cache + return send_handles + return [] - def send_forward_recv_backward( - self, - model_chunk_id_send: int, - model_chunk_id_recv: int, - output_tensor: Any, - next_rank: Optional[int] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: + def send_forward_recv_forward( + self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True + ) -> Tuple[Any, List]: with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): - send_data = not self.stage_manager.is_last_stage() + is_send = not self.stage_manager.is_last_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): - recv_data = not self.stage_manager.is_last_stage() - - if send_data and recv_data: - if not self.send_forward_recv_backward and self.grad_metadata_recv is not None: - send_prior_fallback = None # must not fallback - output_tensor_grad = self.comm.send_forward_recv_backward( - output_tensor, - next_rank, - send_metadata=self.send_tensor_metadata, - metadata_recv=self.grad_metadata_recv, - send_prior_fallback=send_prior_fallback, - ) - self.send_tensor_metadata = not self.enable_metadata_cache - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) - return output_tensor_grad + is_recv = not self.stage_manager.is_first_stage() + input_tensor, wait_handles = self.comm.send_forward_recv_forward( + output_tensor, + is_send, + is_recv, + send_metadata=self.send_tensor_metadata, + metadata_recv=self.tensor_metadata_recv, + send_first=send_first, + ) + # Cache metadata + self.send_tensor_metadata = not self.enable_metadata_cache and is_send + if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_send_metadata(input_tensor) + + return input_tensor, wait_handles - # send only or recv only - self.send_forward(model_chunk_id_send, output_tensor) - return self.recv_backward(model_chunk_id_recv) - - def send_backward_recv_forward( - self, - model_chunk_id_send: int, - model_chunk_id_recv: int, - input_tensor_grad: Any, - prev_rank: Optional[int] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: + def send_backward_recv_backward( + self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True + ) -> Tuple[Any, List]: with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): - send_data = not self.stage_manager.is_first_stage() + is_send = not self.stage_manager.is_first_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): - recv_data = not self.stage_manager.is_first_stage() - - if send_data and recv_data: - if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None: - send_prior_fallback = None # must not fallback - input_tensor = self.comm.send_backward_recv_forward( - input_tensor_grad, - prev_rank, - send_metadata=self.send_grad_metadata, - metadata_recv=self.tensor_metadata_recv, - send_prior_fallback=send_prior_fallback, - ) - self.send_grad_metadata = not self.enable_metadata_cache - if self.enable_metadata_cache and 
self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) - return input_tensor - - # send only or recv only - self.send_backward(model_chunk_id_send, input_tensor_grad) - return self.recv_forward(model_chunk_id_recv) - - def send_forward_recv_forward( - self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool - ): - if send_prior: - self.send_forward(model_chunk_id_send, output_tensor) - input_tensor = self.recv_forward(model_chunk_id_recv) - else: - input_tensor = self.recv_forward(model_chunk_id_recv) - self.send_forward(model_chunk_id_send, output_tensor) - - return input_tensor - - def send_backward_recv_backward( - self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool - ): - if send_prior: - self.send_backward(model_chunk_id_send, input_tensor_grad) - output_tensor_grad = self.recv_backward(model_chunk_id_recv) - else: - output_tensor_grad = self.recv_backward(model_chunk_id_recv) - self.send_backward(model_chunk_id_send, input_tensor_grad) - - return output_tensor_grad + is_recv = not self.stage_manager.is_last_stage() + output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward( + input_tensor_grad, + is_send, + is_recv, + send_metadata=self.send_grad_metadata, + metadata_recv=self.grad_metadata_recv, + send_first=send_first, + ) + # Cache metadata + self.send_grad_metadata = not self.enable_metadata_cache and is_send + if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, wait_handles def forward_step( self, @@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule): Returns: Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ + # Load input ids, attention mask and labels micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # for the first stage, input_obj is None - # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + # for other stages, input_obj is the output of the previous stage containing hidden_states etc. 
+ # Only attention_mask from micro_batch is used with self.stage_manager.switch_model_chunk_id(model_chunk_id): if isinstance(model_chunk, ModuleList): @@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule): if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_current_device()) + fwd_wait_handles = [] model_chunk_id = self.get_model_chunk_id(0, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id) for i in range(self.num_microbatch * self.num_model_chunks): - last_iteration = i == self.num_microbatch * self.num_model_chunks - 1 + last_batch = i == self.num_microbatch * self.num_model_chunks - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=True) + + # Wait until current input is received + _wait_p2p(fwd_wait_handles) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - if not last_iteration: - input_obj = self.send_forward_recv_forward( + if not last_batch: + input_obj, fwd_wait_handles = self.send_forward_recv_forward( model_chunk_id_send=model_chunk_id, model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), output_tensor=output_obj, - send_prior=self.stage_manager.stage % 2 == 0, + send_first=self.stage_manager.stage % 2 == 0, ) else: - self.send_forward(model_chunk_id, output_obj) + fwd_wait_handles = self.send_forward(model_chunk_id, output_obj) if outputs is not None: outputs = merge_batch(outputs) @@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule): self.load_batch(data_iter) num_microbatch = self.num_microbatch * self.num_model_chunks + # Forward + until 1st backward num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + # Steps needed to reach the last chunk num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) num_microbatch_remaining = num_microbatch - num_warmup_microbatch @@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule): if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_current_device()) + bwd_wait_handles = [] + # Get the 1st input batch model_chunk_id = self.get_model_chunk_id(0, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id) + # Run warmup forward passes. 
for i in range(num_warmup_microbatch): - last_iteration = i == num_warmup_microbatch - 1 + last_batch = i == num_warmup_microbatch - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=True) + + # Wait for input + _wait_p2p(fwd_wait_handles) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) - if last_iteration and num_microbatch_remaining == 0: - self.send_forward(model_chunk_id, output_obj) + if last_batch and num_microbatch_remaining == 0: + fwd_wait_handles = self.send_forward(model_chunk_id, output_obj) else: - input_obj = self.send_forward_recv_forward( + input_obj, fwd_wait_handles = self.send_forward_recv_forward( model_chunk_id_send=model_chunk_id, model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), output_tensor=output_obj, - send_prior=self.stage_manager.stage % 2 == 0, + send_first=self.stage_manager.stage % 2 == 0, ) if num_microbatch_remaining > 0: model_chunk_id = self.get_model_chunk_id(0, is_forward=False) - output_obj_grad = self.recv_backward(model_chunk_id) + output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id) # Run 1F1B in steady state. for i in range(num_microbatch_remaining): - last_iteration = i == num_microbatch_remaining - 1 + fwd_batch_id = i + num_warmup_microbatch + last_batch = i == num_microbatch_remaining - 1 + model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True) - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) + # Wait for input. + _wait_p2p(fwd_wait_handles) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) # Add input_obj and output_obj to end of list. input_objs[model_chunk_id].append(input_obj) @@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule): # Pop output_obj and output_obj from the start of the list for the backward pass. 
_input_obj = input_objs[model_chunk_id].pop(0) _output_obj = output_objs[model_chunk_id].pop(0) - input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) - # NOTE: perform 2x communication for forward and backward - def send_forward_recv_backward(): - if last_iteration and num_microbatch == num_microbatch_remaining: - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) - self.send_forward(model_chunk_id, output_obj) + # Helper functions + def send_forward_recv_forward(): + if last_batch: + model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True) + wait_handles = self.send_forward(model_chunk_id, output_obj) + return None, wait_handles else: - output_obj_grad = self.send_forward_recv_backward( - model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True), - model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), + input_obj, wait_handles = self.send_forward_recv_forward( + model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True), + model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True), output_tensor=output_obj, - send_prior_fallback=self.stage_manager.stage % 2 == 0, + send_first=self.stage_manager.stage % 2 == 0 + and i > 0, # Receive from warmup stage first in the first batch ) - return output_obj_grad + return input_obj, wait_handles - def send_backward_recv_forward(): - if last_iteration: + def send_backward_recv_backward(): + no_cooldown = num_microbatch == num_microbatch_remaining + if last_batch and no_cooldown: model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) + wait_handles = self.send_backward(model_chunk_id, input_obj_grad) + return None, wait_handles else: - input_obj = self.send_backward_recv_forward( + output_obj_grad, wait_handles = self.send_backward_recv_backward( model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), - model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True), + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), input_tensor_grad=input_obj_grad, - send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0, + send_first=self.stage_manager.stage % 2 == 0, ) - return input_obj + return output_obj_grad, wait_handles - if self.stage_manager.stage % 2 == 0: - output_obj_grad = send_forward_recv_backward() - input_obj = send_backward_recv_forward() - else: - input_obj = send_backward_recv_forward() - output_obj_grad = send_forward_recv_backward() + input_obj, fwd_wait_handles = send_forward_recv_forward() + # Wait for upstream grad + _wait_p2p(bwd_wait_handles) + input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) + # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv) + # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html) + # however in practice this works fine, and Megatron does this too + # (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774) + # if deadlock, call _wait_p2p(fwd_wait_handles) here + output_obj_grad, bwd_wait_handles = send_backward_recv_backward() if num_microbatch_remaining == 0: model_chunk_id = self.get_model_chunk_id(0, is_forward=False) - output_obj_grad = self.recv_backward(model_chunk_id) + output_obj_grad, 
bwd_wait_handles = self.recv_backward(model_chunk_id) + # Run cooldown backward passes. for i in range(num_microbatch_remaining, num_microbatch): - last_iteration = i == num_microbatch - 1 + last_batch = i == num_microbatch - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=False) _input_obj = input_objs[model_chunk_id].pop(0) _output_obj = output_objs[model_chunk_id].pop(0) - # output_obj_grad = self.recv_backward(model_chunk_id) - input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) - if not last_iteration: - output_obj_grad = self.send_backward_recv_backward( + # Wait for upstream grad + _wait_p2p(bwd_wait_handles) + # backward local grads + input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) + if not last_batch: + output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward( model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), input_tensor_grad=input_obj_grad, - send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining, + send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining, ) + assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0 else: model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) + _ = self.send_backward(model_chunk_id, input_obj_grad) assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index bfea8b67d..7f0d0e349 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): num_microbatches is not None or microbatch_size is not None ), "Either num_microbatches or microbatch_size should be provided" - self.comm = PipelineP2PCommunication(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False) + self.num_microbatches = num_microbatches self.microbatch_size = microbatch_size self.batch: Optional[Any] = None @@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Any: The input tensor or input tensor list. """ if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) @@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Any: The input gradient tensor or gradient tensor list. 
""" if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) @@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) self.send_grad_metadata = not self.enable_metadata_cache - def send_forward_recv_backward( - self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None - ) -> Any: + def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. For 1F1B. @@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ if not self.stage_manager.is_last_stage(): if not self.send_tensor_metadata and self.grad_metadata_recv is not None: - send_prior_fallback = None # must not fallback - output_tensor_grad = self.comm.send_forward_recv_backward( + send_first = None + output_tensor_grad, _ = self.comm.send_forward_recv_backward( output_tensor, - next_rank, send_metadata=self.send_tensor_metadata, metadata_recv=self.grad_metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, ) self.send_tensor_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.grad_metadata_recv is None: @@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): return output_tensor_grad - def send_backward_recv_forward( - self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None - ) -> Any: + def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any: """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. For 1F1B. @@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ if not self.stage_manager.is_first_stage(): if not self.send_grad_metadata and self.tensor_metadata_recv is not None: - send_prior_fallback = None # must not fallback - input_tensor = self.comm.send_backward_recv_forward( + send_first = None # must not fallback + input_tensor, _ = self.comm.send_backward_recv_forward( input_tensor_grad, - prev_rank, send_metadata=self.send_grad_metadata, metadata_recv=self.tensor_metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, ) self.send_grad_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.tensor_metadata_recv is None: @@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - output_obj_grad = self.send_forward_recv_backward( - output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0 - ) + output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0) # Add input_obj and output_obj to end of list. 
input_objs.append(input_obj) output_objs.append(output_obj) @@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_backward(input_obj_grad) else: input_obj = self.send_backward_recv_forward( - input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0 + input_obj_grad, send_first=self.stage_manager.stage % 2 == 0 ) # Run cooldown backward passes. diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b7cbd67ab..354f110f0 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -35,7 +35,7 @@ class PipelineStageManager: self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None - self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {} if num_layers_per_stage is not None: assert len(num_layers_per_stage) == self.num_stages self.num_layers_per_stage = num_layers_per_stage @@ -48,30 +48,14 @@ class PipelineStageManager: # the next rank of the last rank is rank0 next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") - - # init p2p process groups - stages = list(range(self.num_stages)) - for prev, cur in zip(stages[:-1], stages[1:]): - group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur]) - if self.stage in [prev, cur]: - ranks_in_group = self.pg_mesh.get_ranks_in_group(group) - self.p2p_groups[tuple(ranks_in_group)] = group - self.is_interleave = enable_interleave # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks - if enable_interleave: - # use circle p2p communication - # add the process group of the first rank and the last rank - group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) - if self.stage in [stages[0], stages[-1]]: - ranks_in_group = self.pg_mesh.get_ranks_in_group(group) - self.p2p_groups[tuple(ranks_in_group)] = group - - # for shardformer, hold stage indices of model - self.stage_indices: List[Tuple[int, int]] - # for shardformer, hold model chunk id - self.model_chunk_id: Optional[int] = None + # for shardformer, hold stage indices of model + self.stage_indices: List[Tuple[int, int]] + # for shardformer, hold model chunk id + self.model_chunk_id: Optional[int] = None + self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis) def get_stage_index( self, @@ -184,19 +168,12 @@ class PipelineStageManager: """ return self.next_rank - def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: + def get_p2p_process_group(self) -> ProcessGroup: """Get the p2p process group between two ranks. The order of the two ranks does not matter. - - Args: - first_rank (int): The first rank. - second_rank (int): The second rank. - Returns: ProcessGroup: P2P process group between the two ranks. """ - if first_rank > second_rank: - first_rank, second_rank = second_rank, first_rank - return self.p2p_groups[(first_rank, second_rank)] + return self.p2p_group def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: """Get the process group of the given stages. 
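The core of the p2p.py and stage_manager.py changes above is that each pipeline step now folds its send and receive into a single dist.batch_isend_irecv call on one shared process group, lets neighbouring stages enqueue the two operations in opposite orders (send_first is typically stage % 2 == 0), and, when overlap_p2p is enabled, hands the request handles back to the schedule instead of waiting on them. The following is a minimal, self-contained sketch of that pattern; the helper name batched_send_recv and its arguments are illustrative only and are not part of the ColossalAI API.

```python
# Minimal sketch (not ColossalAI code) of the pattern used by the patched
# _batch_send_recv_tensor: both directions go through one dist.batch_isend_irecv
# call, and waiting is left to the caller when overlap_p2p is enabled.
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist


def batched_send_recv(
    send_tensor: Optional[torch.Tensor],
    recv_buffer: Optional[torch.Tensor],
    send_dst: Optional[int],
    recv_src: Optional[int],
    group: Optional[dist.ProcessGroup] = None,
    send_first: bool = True,
    overlap_p2p: bool = True,
) -> Tuple[Optional[torch.Tensor], List]:
    ops = []
    send_op = (
        dist.P2POp(dist.isend, send_tensor, send_dst, group=group)
        if send_tensor is not None and send_dst is not None
        else None
    )
    recv_op = (
        dist.P2POp(dist.irecv, recv_buffer, recv_src, group=group)
        if recv_buffer is not None and recv_src is not None
        else None
    )
    # Neighbouring stages pass opposite send_first values (e.g. stage % 2 == 0),
    # so each isend is matched by the peer's irecv within one batched call.
    for op in (send_op, recv_op) if send_first else (recv_op, send_op):
        if op is not None:
            ops.append(op)

    reqs = dist.batch_isend_irecv(ops) if ops else []
    if not overlap_p2p:
        for req in reqs:  # blocking behaviour, as used by the 1F1B schedule
            req.wait()
        return recv_buffer, []
    # With overlap enabled, the caller keeps the requests and waits only right
    # before the received buffer is consumed, hiding the transfer behind compute.
    return recv_buffer, reqs
```

The deferred req.wait() is what lets the interleaved schedule start the next forward or backward step while the transfer is still in flight; the 1F1B schedule keeps the old blocking behaviour by constructing PipelineP2PCommunication with overlap_p2p=False.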
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index f6c975305..4b897770e 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -1,9 +1,11 @@ import argparse import resource import time +import warnings from contextlib import nullcontext import torch +import torch.distributed as dist from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator, get_profile_context @@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import PipelineGradientCheckpointConfig +warnings.filterwarnings("ignore") # ============================== # Constants # ============================== MODEL_CONFIGS = { + "100m": LlamaConfig( + max_position_embeddings=4096, + num_hidden_layers=4, + num_attention_heads=32, + intermediate_size=2048, + hidden_size=1024, + ), "7b": LlamaConfig(max_position_embeddings=4096), "13b": LlamaConfig( hidden_size=5120, @@ -58,6 +68,9 @@ def main(): default="gemini", help="Choose which plugin to use", ) + parser.add_argument( + "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel." + ) parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") @@ -78,11 +91,13 @@ def main(): parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) - parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) - parser.add_argument( - "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False - ) + + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) + parser.add_argument("--profile", action="store_true", help="Profile the code", default=False) + parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") + parser.add_argument("--no_cache", action="store_true") args = parser.parse_args() colossalai.launch_from_torch() @@ -98,6 +113,7 @@ def main(): num_ckpt_layers_per_stage=[19, 19, 19, 13], ), "num_layers_per_stage": [19, 20, 20, 21], + "pp_style": "interleaved", } if args.custom_ckpt else {} @@ -174,6 +190,8 @@ def main(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, zero_stage=args.zero, sp_size=args.sp, enable_sequence_parallelism=args.sp > 1, @@ -182,12 +200,16 @@ def main(): microbatch_size=args.mbs, precision="bf16", dp_outside=False, + overlap_p2p=args.overlap, + enable_metadata_cache=not args.no_cache, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, zero_stage=args.zero, cpu_offload=True, 
enable_fused_normalization=torch.cuda.is_available(), @@ -195,6 +217,7 @@ def main(): microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", + overlap_p2p=args.overlap, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -210,10 +233,11 @@ def main(): config = MODEL_CONFIGS[args.config] else: config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + torch.cuda.manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) - dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42) # ============================== # Initialize Model and Optimizer @@ -251,6 +275,7 @@ def main(): optimizer = HybridAdam(model.parameters()) torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + torch.set_default_dtype(torch.float) coordinator.print_on_master( f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" @@ -269,15 +294,19 @@ def main(): data_iter = iter(dataloader) for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): performance_evaluator.on_step_start(step) - booster.execute_pipeline( + outputs = booster.execute_pipeline( data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, - return_loss=False, + return_loss=True, ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) prof.step() else: @@ -288,6 +317,7 @@ def main(): booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) prof.step() diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 48a8d12e0..30b557f5e 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -15,8 +15,7 @@ WORLD_SIZE = 2 def check_p2p_communication(): pg_mesh = ProcessGroupMesh(WORLD_SIZE) stage_manager = PipelineStageManager(pg_mesh, 0) - p2p = PipelineP2PCommunication(stage_manager) - + p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False) rank = dist.get_rank() tensor = torch.ones(1, device=get_accelerator().get_current_device()) @@ -31,41 +30,40 @@ def check_p2p_communication(): for obj in data: p2p.send_forward(obj) for i in range(len(data)): - recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False) + recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False) assert recv_obj == data[-(i + 1)] elif rank == 1: for obj in data: - recv_obj = p2p.recv_forward() + recv_obj, _ = p2p.recv_forward() assert recv_obj == obj for i in range(len(data)): p2p.send_backward(data[-(i + 1)]) - recv_obj = p2p.recv_forward() + recv_obj, _ = p2p.recv_forward() assert recv_obj == data[i] if rank == 1: for obj in data: p2p.send_backward(obj) for i in range(len(data)): - recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True) + recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True) assert recv_obj == data[-(i + 1)] elif rank == 0: for obj in data: - recv_obj = p2p.recv_backward() + 
recv_obj, _ = p2p.recv_backward() assert recv_obj == obj for i in range(len(data)): - recv_obj = p2p.recv_backward() - p2p.send_forward(data[-(i + 1)]) + recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False) assert recv_obj == data[i] if rank == 0: - recv_obj = p2p.send_forward_recv_backward( + recv_obj, _ = p2p.send_forward_recv_backward( tensor, send_metadata=False, metadata_recv=create_send_metadata(tensor), ) assert recv_obj == tensor elif rank == 1: - recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor)) + recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor)) assert recv_obj == tensor p2p.send_backward(tensor, send_metadata=False) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index 5146a86c8..a3793013b 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -52,7 +52,7 @@ def check_stage_manager(): # check p2p groups for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): if rank in [prev, cur]: - group = stage_manager.get_p2p_process_group(prev, cur) + group = stage_manager.get_p2p_process_group() dist.barrier(group=group) # check stage groups
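Taken together, the new calling convention for the schedules is: every send/recv returns a (data, wait_handles) pair, and the schedule blocks on the handles only right before the data is consumed. Below is a rough sketch of a forward-only loop under that convention, mirroring the run_forward_only changes above; comm stands for a PipelineP2PCommunication built with overlap_p2p=True, forward_step is a placeholder for the per-microbatch compute, and first/last-stage handling is omitted for brevity.

```python
# Rough sketch of the (data, wait_handles) convention used by the schedules above.
# `comm` is assumed to be a PipelineP2PCommunication with overlap_p2p=True and
# `forward_step` the per-microbatch compute; stage-boundary handling is omitted.
def _wait_p2p(wait_handles):
    if wait_handles is not None:
        for req in wait_handles:
            req.wait()


def forward_microbatches(comm, forward_step, num_microbatch):
    input_obj, fwd_wait_handles = comm.recv_forward()
    for i in range(num_microbatch):
        _wait_p2p(fwd_wait_handles)           # block only when input_obj is needed
        output_obj = forward_step(input_obj)  # compute overlaps the in-flight p2p
        if i < num_microbatch - 1:
            # launch the send and the next receive together, keep the handles
            input_obj, fwd_wait_handles = comm.send_forward_recv_forward(
                output_obj, is_send=True, is_recv=True, send_first=True
            )
        else:
            comm.send_forward(output_obj)
```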
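Finally, for readers who want to exercise the feature end to end, here is a hedged usage example based on the flags the benchmark above now exposes (--pp_style interleaved --n_chunks 2 --overlap). The parallelism sizes are illustrative, and the model, optimizer and dataloader are assumed to exist already; this is a configuration sketch, not a complete training script.

```python
# Illustrative configuration only: enable interleaved pipeline parallelism with
# overlapped p2p communication, mirroring examples/language/llama/benchmark.py.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,                  # four pipeline stages
    pp_style="interleaved",
    num_model_chunks=2,         # two virtual stages (model chunks) per rank
    microbatch_size=1,
    zero_stage=1,
    precision="bf16",
    overlap_p2p=True,           # new flag: return p2p handles instead of blocking
    enable_metadata_cache=True,
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
# outputs = booster.execute_pipeline(
#     iter(dataloader), model, criterion=lambda outputs, inputs: outputs[0],
#     optimizer=optimizer, return_loss=True,
# )
```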