From 4fa689fca1ecf50b8e905cf1c74d4a2c08219daf Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 22 Dec 2023 10:44:00 +0800 Subject: [PATCH] [pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134) * test: add more p2p tests * fix: remove send_forward_recv_forward as p2p op list need to use the same group * fix: make send and receive atomic * feat: update P2PComm fn * feat: add metadata cache in 1f1b * feat: add metadata cache in interleaved pp * feat: modify is_xx_stage fn * revert: add _broadcast_object_list * feat: add interleaved pp in llama policy * feat: set NCCL_BUFFSIZE in HybridParallelPlugin --- .../booster/plugin/hybrid_parallel_plugin.py | 11 +- colossalai/pipeline/p2p.py | 396 ++++++++++-------- .../pipeline/schedule/interleaved_pp.py | 189 +++++---- colossalai/pipeline/schedule/one_f_one_b.py | 135 +++--- colossalai/pipeline/stage_manager.py | 38 +- colossalai/shardformer/policies/bert.py | 20 +- colossalai/shardformer/policies/llama.py | 86 +++- examples/language/bert/finetune.py | 8 +- examples/language/llama2/benchmark.py | 2 + tests/test_pipeline/test_p2p_communication.py | 79 +++- .../test_schedule/test_interleaved.py | 34 +- .../test_schedule/test_oneF_oneB.py | 133 ++++-- tests/test_shardformer/test_model/_utils.py | 6 +- .../test_model/test_shard_bert.py | 20 +- .../test_model/test_shard_llama.py | 17 +- 15 files changed, 728 insertions(+), 446 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 91fcba55a..ea74f75f4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,4 +1,5 @@ import ctypes +import os import random from contextlib import contextmanager from functools import partial @@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh -from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase): self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: + if os.getenv("NCCL_BUFFSIZE") is None: + logger = get_dist_logger() + logger.warning( + "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen." + ) + os.environ["NCCL_BUFFSIZE"] = "134217728" + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 6e49fa36b..cdb7a6a1e 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -4,13 +4,13 @@ import io import pickle import re -from typing import Any, List, Optional, Union from collections import namedtuple +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, List, Optional, Union import torch import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum from packaging.version import Version from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d @@ -20,7 +20,7 @@ from .stage_manager import PipelineStageManager _unpickler = pickle.Unpickler -def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: +def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any: """transform tensor to object with unpickle. Info of the device in bytes stream will be modified into current device before unpickling @@ -48,21 +48,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle -def check_for_nccl_backend(group): - pg = group or c10d._get_default_group() - # Gate PG wrapper check on Gloo availability. - if c10d._GLOO_AVAILABLE: - # It is not expected for PG to be wrapped many times, but support it just - # in case - while isinstance(pg, c10d._ProcessGroupWrapper): - pg = pg.wrapped_pg - - return ( - c10d.is_nccl_available() and - pg.name() == c10d.Backend.NCCL - ) - - +# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use def _broadcast_object_list( object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None ): @@ -70,13 +56,11 @@ def _broadcast_object_list( The only difference is that object will be move to correct device after unpickled. If local_rank = src, then object list will be sent to rank src. Otherwise, object list will be updated with data sent from rank src. - Args: object_list (List[Any]): list of object to broadcast src (int): source rank to broadcast dst (int): dst rank to broadcast device (:class:`torch.device`): device to do broadcast. current device in default - """ if c10d._rank_not_in_group(group): @@ -131,7 +115,7 @@ def _broadcast_object_list( if my_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset: offset + obj_size] + obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -149,6 +133,18 @@ def _broadcast_object_list( object_list[i] = unpickle_object +def check_for_nccl_backend(group): + pg = group or c10d._get_default_group() + # Gate PG wrapper check on Gloo availability. + if c10d._GLOO_AVAILABLE: + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, c10d._ProcessGroupWrapper): + pg = pg.wrapped_pg + + return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL + + def check_device(group): is_nccl_backend = check_for_nccl_backend(group) current_device = None @@ -159,14 +155,14 @@ def check_device(group): return current_device, is_nccl_backend -TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad']) +TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"]) class P2PDataType(Enum): - serialization = 0 - tensor = 1 - list = 2 - dict = 3 + Serialization = 0 + Tensor = 1 + List = 2 + Dict = 3 @dataclass @@ -175,45 +171,71 @@ class P2PMetadata: content: Union[List[TensorMetadata], TensorMetadata, Any] -def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): +def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup): if isinstance(obj, torch.Tensor): obj = obj.contiguous() op_to_add = dist.P2POp(comm_op, obj, comm_rank, group) ops_queue.append(op_to_add) else: for tensor_to_comm in obj: - tensor_to_comm = tensor_to_comm.contiguous() - op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group) - ops_queue.append(op_to_add) + assert isinstance(tensor_to_comm, torch.Tensor) + filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group) -def create_recv_buffer(p2p_metadata: P2PMetadata, current_device): - if p2p_metadata.data_type == P2PDataType.tensor: +def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any): + if p2p_metadata.data_type == P2PDataType.Tensor: metadata = p2p_metadata.content - tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) + tensor_recv = torch.empty( + metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype + ) return tensor_recv - elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict): + elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict): buffer_recv = [] for metadata in p2p_metadata.content: - tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) + tensor_recv = torch.empty( + metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype + ) buffer_recv.append(tensor_recv) return buffer_recv else: raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}") -def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device): +def create_fast_send_metadata(object: Any) -> P2PMetadata: + assert _check_if_fast_send_available(object) + if isinstance(object, torch.Tensor): + data_type = P2PDataType.Tensor + content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) + elif isinstance(object, list): + data_type = P2PDataType.List + content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object] + elif isinstance(object, dict): + data_type = P2PDataType.Dict + content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()] + else: + raise RuntimeError("Cannot handle object of type {}".format(type(object))) + return P2PMetadata(data_type, content) + + +def _batch_send_recv_tensor( + send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]], + recv_tensor_metadata: Optional[P2PMetadata], + send_dst: Optional[int], + recv_src: Optional[int], + send_group: Optional[ProcessGroup], + recv_group: Optional[ProcessGroup], + current_device: Any, +) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]: buffer_recv = None - if recv_tensor_metadata is not None: + if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization: buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device) ops = [] - - if send_dst is not None: + if send_dst is not None and send_tensor_list is not None: + assert send_group is not None filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) - - if recv_src is not None: - assert buffer_recv is not None + if recv_src is not None and buffer_recv is not None: + assert recv_group is not None filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) if len(ops) > 0: @@ -221,24 +243,26 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re for req in reqs: req.wait() - torch.cuda.synchronize() - # Remove synchronization according to Pytorch's documentation # However, the Megatron-LM does synchronization here # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112 # In case there is potential error, uncomment the following `torch.cuda.synchronize()` - # torch.cuda.synchronize() + torch.cuda.synchronize() return buffer_recv def _send_recv_serialization_object( - object: Any, - send_dst: Optional[int], recv_src: Optional[int], - send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup], - current_device, - is_nccl_backend): + object: Any, + send_dst: Optional[int], + recv_src: Optional[int], + send_group: Optional[ProcessGroup], + recv_group: Optional[ProcessGroup], + current_device: Any, + is_nccl_backend: bool, +) -> Optional[P2PMetadata]: ops = [] + send_object_tensor = None if object is not None and send_dst is not None: if Version(torch.__version__) >= Version("1.13.0"): @@ -264,10 +288,8 @@ def _send_recv_serialization_object( for req in reqs: req.wait() - torch.cuda.synchronize() - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() + torch.cuda.synchronize() ops = [] @@ -286,52 +308,77 @@ def _send_recv_serialization_object( for req in reqs: req.wait() - torch.cuda.synchronize() - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() + torch.cuda.synchronize() if recv_object_tensor is not None and recv_object_size_tensor is not None: recv_object_tensor = recv_object_tensor.type(torch.uint8) if recv_object_tensor.device != torch.device("cpu"): recv_object_tensor = recv_object_tensor.cpu() - unpickle_object = _cuda_safe_tensor_to_object( - recv_object_tensor, recv_object_size_tensor.item()) + unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item()) - if ( - isinstance(unpickle_object, torch.Tensor) - and unpickle_object.device.index != torch.cuda.current_device() - ): + if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): unpickle_object = unpickle_object.cuda() return unpickle_object -def _check_if_fast_send_available(object): - if type(object) is torch.Tensor: +def _check_if_fast_send_available(object: Any) -> bool: + if isinstance(object, torch.Tensor): return True - elif type(object) is list: - is_list_of_tensor = all([type(v) is torch.Tensor for v in object]) + elif isinstance(object, list): + is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object]) return is_list_of_tensor - elif type(object) is dict: - is_dict_of_tensor = all([type(k) is str and type( - v) is torch.Tensor for k, v in object.items()]) - + elif isinstance(object, dict): + is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()]) return is_dict_of_tensor return False def _communicate( - object, + object: Any, send_dst: Optional[int], recv_src: Optional[int], send_group: Optional[ProcessGroup] = None, recv_group: Optional[ProcessGroup] = None, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, ) -> Any: - if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group): - c10d._warn_not_in_group("_communicate") - return + """ + Send and receive object from send_dst and recv_src respectively + + Args: + object (Any): object needed to be sent + send_dst (int): rank of the destination + recv_src (int): rank of the source + send_group (ProcessGroup, optional): process group of sender + recv_group (ProcessGroup, optional): process group of receiver + send_metadata (bool, optional): whether to send metadata + metadata_recv (P2PMetadata, optional): metadata of the object to be received + """ + assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None" + assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None" + assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None" + send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object)) + assert ( + metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization + ), "metadata_recv type must not be Serialization" + + # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata, + # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. + if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): + _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) + return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) + + # NOTE: only the following 5 cases are valid: + # 1. send() [needs extra metadata] and no recv() + # 2. recv() [needs extra metadata] and no send() + # 3. neither send() nor recv() need extra metadata + assert not (send_dst is not None and send_metadata) or recv_src is None + assert not (recv_src is not None and metadata_recv is None) or send_dst is None + assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None) + assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group) current_send_device, is_send_nccl_backend = check_device(send_group) current_recv_device, is_recv_nccl_backend = check_device(recv_group) @@ -341,67 +388,56 @@ def _communicate( assert current_send_device == current_recv_device current_device = current_send_device - assert (send_dst is not None) or (recv_src is not None) - - can_fast_send = False - send_metadata = None - if send_dst is not None: - can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend - if not can_fast_send: - send_metadata = P2PMetadata(P2PDataType.serialization, object) - else: - if type(object) is torch.Tensor: - data_type = P2PDataType.tensor - content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) - elif type(object) is list: - data_type = P2PDataType.list - content = [] - for v in object: - content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad)) - elif type(object) is dict: - data_type = P2PDataType.dict - content = [] - for k, v in object.items(): - content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad)) + if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None): + metadata_send = None + if send_dst is not None and send_metadata: + can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend + if not can_fast_send: + metadata_send = P2PMetadata(P2PDataType.Serialization, object) else: - raise ValueError('Cannot send object of type {}'.format(type(object))) - send_metadata = P2PMetadata(data_type, content) + metadata_send = create_fast_send_metadata(object) - recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend) - if recv_metadata is not None: - assert type(recv_metadata) is P2PMetadata - if recv_metadata.data_type == P2PDataType.serialization: - return recv_metadata.content - if not can_fast_send and send_dst is not None: - return + # Send and receive metadata + _metadata_recv = _send_recv_serialization_object( + object=metadata_send, + send_dst=send_dst if send_metadata else None, + recv_src=recv_src if metadata_recv is None else None, + send_group=send_group if send_metadata else None, + recv_group=recv_group if metadata_recv is None else None, + current_device=current_device, + is_nccl_backend=is_nccl_backend, + ) + assert metadata_recv is None or _metadata_recv is None + metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv send_tensor_list = None - if type(object) is torch.Tensor: + if isinstance(object, torch.Tensor): send_tensor_list = object - elif type(object) is list: + elif isinstance(object, list): send_tensor_list = object - elif type(object) is dict: + elif isinstance(object, dict): send_tensor_list = list(object.values()) - recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device) + # Send and receive data + recv_buffer = _batch_send_recv_tensor( + send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device + ) - if recv_metadata is not None: - assert recv_buffer is not None - if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]: - return recv_buffer - elif recv_metadata.data_type == P2PDataType.dict: - return { - k: v - for k, v in zip( - [m.key for m in recv_metadata.content], - recv_buffer, - ) - } + if metadata_recv is not None: + assert isinstance(metadata_recv, P2PMetadata) + if metadata_recv.data_type == P2PDataType.Serialization: + return metadata_recv.content else: - raise ValueError('Unknown data type {}'.format(recv_metadata.data_type)) + assert recv_buffer is not None + if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]: + return recv_buffer + elif metadata_recv.data_type == P2PDataType.Dict: + return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)} + else: + raise ValueError("Unknown data type {}".format(metadata_recv.data_type)) -def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: +def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None: """send anything to dst rank Args: @@ -411,10 +447,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: Returns: None """ - _communicate(object, send_dst=dst, recv_src=None, send_group=group) + _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata) -def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: +def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any: """recv anything from src Args: @@ -423,7 +459,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: Returns: Any: Object received from src. """ - return _communicate(None, send_dst=None, recv_src=src, recv_group=group) + return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv) def _p2p_comm( @@ -436,7 +472,7 @@ def _p2p_comm( """ Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. - Agrs: + Args: tensor_send_next (torch.Tensor): tensor to be sent to next stage recv_prev (bool): whether to receive tensor from previous stage peer (int): rank of the peer @@ -467,7 +503,6 @@ def _p2p_comm( group=group, ) ops.append(recv_prev_op) - if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: @@ -490,7 +525,6 @@ def _p2p_comm( group=group, ) ops.append(send_next_op) - if tensor_recv_prev is not None: recv_prev_op = dist.P2POp( dist.irecv, @@ -510,7 +544,7 @@ class PipelineP2PCommunication: def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager - def recv_forward(self, prev_rank: int = None) -> Any: + def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -522,11 +556,13 @@ class PipelineP2PCommunication: if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + input_tensor = _recv_object( + prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv + ) return input_tensor - def recv_backward(self, next_rank: int = None) -> Any: + def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. Args: @@ -539,12 +575,12 @@ class PipelineP2PCommunication: next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() output_tensor_grad = _recv_object( - next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank) + next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv ) return output_tensor_grad - def send_forward(self, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None: """Sends the input tensor to the next stage in pipeline. Args: @@ -554,9 +590,15 @@ class PipelineP2PCommunication: if next_rank is None: next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + _send_object( + output_object, + cur_rank, + next_rank, + self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + send_metadata, + ) - def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None: """Sends the gradient tensor to the previous stage in pipeline. Args: @@ -566,9 +608,21 @@ class PipelineP2PCommunication: if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() - _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + _send_object( + input_object, + cur_rank, + prev_rank, + self.stage_manager.get_p2p_process_group(cur_rank, prev_rank), + send_metadata, + ) - def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any: + def send_forward_recv_backward( + self, + input_object: Any, + next_rank: Optional[int] = None, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline Args: @@ -581,11 +635,22 @@ class PipelineP2PCommunication: cur_rank = self.stage_manager.get_rank() group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) return _communicate( - input_object, next_rank, next_rank, - send_group=group, recv_group=group, + input_object, + next_rank, + next_rank, + send_group=group, + recv_group=group, + send_metadata=send_metadata, + metadata_recv=metadata_recv, ) - def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any: + def send_backward_recv_forward( + self, + input_object: Any, + prev_rank: Optional[int] = None, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline Args: @@ -597,37 +662,22 @@ class PipelineP2PCommunication: cur_rank = self.stage_manager.get_rank() group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) - return _communicate( - input_object, prev_rank, prev_rank, - send_group=group, recv_group=group, - ) - - def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: - """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. - - Args: - input_object (Any): Object to be sent. - prev_rank (int, optional): The rank of the sender of the tensor - next_rank (int, optional): The rank of the recipient of the tensor - """ - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - - cur_rank = self.stage_manager.get_rank() - recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) - send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) return _communicate( input_object, - send_dst=next_rank, - recv_src=prev_rank, - send_group=send_group, - recv_group=recv_group, + prev_rank, + prev_rank, + send_group=group, + recv_group=group, + send_metadata=send_metadata, + metadata_recv=metadata_recv, ) def p2p_communicate( - self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 + self, + output_object: Any, + recv_pre: bool, + next_rank: Optional[int] = None, + comm_dtype: torch.dtype = torch.float16, ) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -636,10 +686,14 @@ class PipelineP2PCommunication: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if peer is None: - peer = self.stage_manager.get_next_rank() + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() recv_tensor = _p2p_comm( - output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype + output_object, + recv_pre, + next_rank, + self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + comm_dtype, ) return recv_tensor diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 7c3f15e80..3c8b00977 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper -from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule): assert ( num_microbatch is not None or microbatch_size is not None ), "Either num_microbatch or microbatch_size should be provided" + self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size @@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule): self.batch: Any self.batch_size: int + self.last_batch_size: Optional[int] = None self.microbatch_offset: List[int] + # P2PMeta cache + self.send_metadata_forward = True + self.send_metadata_backward = True + self.metadata_recv_forward = None + self.metadata_recv_backward = None + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule): batch = tree_map(partial(to_device, device=device), batch) self.batch = batch self.batch_size = get_batch_size(batch) + if self.last_batch_size is None: + self.last_batch_size = self.batch_size + else: + assert self.forward_only or self.last_batch_size == self.batch_size + # TODO: support arbitrary batch size when forward_only=True self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] if self.num_microbatch is not None: assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" @@ -106,12 +119,13 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input tensor or input tensor list. """ - if self.stage_manager.is_first_stage(model_chunk_id): - input_tensor = None - else: - input_tensor = self.comm.recv_forward(prev_rank) + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) - return input_tensor + return input_tensor def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. @@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(model_chunk_id): - output_tensor_grad = None - else: - output_tensor_grad = self.comm.recv_backward(next_rank) + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) - return output_tensor_grad + return output_tensor_grad - def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. For interleaved 1F1B. @@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule): output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.stage_manager.is_last_stage(model_chunk_id): - self.comm.send_forward(output_object, next_rank) + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) + self.send_metadata_forward = False - def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For interleaved 1F1B. @@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule): input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.stage_manager.is_first_stage(model_chunk_id): - self.comm.send_backward(input_object, prev_rank) + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) + self.send_metadata_backward = False + + def send_forward_recv_backward( + self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None + ) -> Any: + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.send_forward_recv_backward( + output_object, + next_rank, + send_metadata=self.send_metadata_forward, + metadata_recv=self.metadata_recv_backward, + ) + self.send_metadata_forward = False + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + + return output_tensor_grad + + def send_backward_recv_forward( + self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None + ) -> Any: + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.send_backward_recv_forward( + output_object, + prev_rank, + send_metadata=self.send_metadata_backward, + metadata_recv=self.metadata_recv_forward, + ) + self.send_metadata_backward = False + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + + return input_tensor def forward_step( self, @@ -180,25 +233,24 @@ class InterleavedSchedule(PipelineSchedule): # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict - self.stage_manager.model_chunk_id = model_chunk_id - if isinstance(model_chunk, ModuleList): - output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) - else: - # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers - internal_inputs = {} if input_obj is None else input_obj - internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - self.stage_manager.model_chunk_id = None + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if isinstance(model_chunk, ModuleList): + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + else: + # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - if self.stage_manager.is_last_stage(model_chunk_id): - loss = criterion(output_obj, micro_batch) / self.num_microbatch - if accum_loss is not None: - accum_loss.add_(loss.detach()) - if outputs is not None: - outputs.append(tree_map(detach, output_obj)) - return loss - else: - return output_obj + if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatch + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj def backward_step( self, @@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule): Returns: dict: A dict with keys: 'loss' and 'outputs'. """ - # TODO: handle arbitrary batch size when forward_only == True - forward_only = not torch.is_grad_enabled() + self.forward_only = not torch.is_grad_enabled() if optimizer is None: - assert forward_only, "Optimizer should be passed when doing backward." + assert self.forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) num_microbatch = self.num_microbatch * self.num_model_chunks - if forward_only: + if self.forward_only: num_warmup_microbatch = num_microbatch else: num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 @@ -288,43 +339,29 @@ class InterleavedSchedule(PipelineSchedule): input_objs = None output_objs = None - if not forward_only: + if not self.forward_only: input_objs = [[] for _ in range(self.num_model_chunks)] output_objs = [[] for _ in range(self.num_model_chunks)] - outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None + outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None - if return_loss and self.stage_manager.is_last_stage(-1): + accum_loss = None + if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None - - # for ranks except the first one, get into recv state - input_obj = self.recv_forward(0) # Run warmup forward passes. for i in range(num_warmup_microbatch): model_chunk_id = self.get_model_chunk_id(i, is_forward=True) - # recv first on first rank to avoid sending or receiving at the same time - if self.stage_manager.is_first_stage(-1): - input_obj = self.recv_forward(model_chunk_id) - output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - self.send_forward(model_chunk_id, output_obj) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) - output_objs[model_chunk_id].append(output_obj) - else: - output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) - output_objs[model_chunk_id].append(output_obj) - self.send_forward(model_chunk_id, output_obj) + input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if not self.forward_only: + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + self.send_forward(model_chunk_id, output_obj) - if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch: - break - - model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + if num_microbatch_remaining > 0: + model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True) + input_obj = self.recv_forward(model_chunk_id) # Run 1F1B in steady state. for i in range(num_microbatch_remaining): @@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule): last_iteration = i == num_microbatch_remaining - 1 output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - if forward_only: - self.send_forward(model_chunk_id, output_obj) - + if self.forward_only: if not last_iteration: - input_obj = self.recv_forward(model_chunk_id) + input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj) + else: + self.send_forward(model_chunk_id, output_obj) else: self.send_forward(model_chunk_id, output_obj) @@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule): # backward input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) - if last_iteration: - input_obj = None - else: + if not last_iteration: model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) input_obj = self.recv_forward(model_chunk_id) - model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) - # Run cooldown backward passes. - if not forward_only: + if not self.forward_only: for i in range(num_microbatch_remaining, num_microbatch): model_chunk_id = self.get_model_chunk_id(i, is_forward=False) input_obj = input_objs[model_chunk_id].pop(0) @@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule): input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(model_chunk_id, input_obj_grad) - if not forward_only: + if not self.forward_only: assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) if outputs is not None: diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index fd918cf19..8c161efec 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,7 +7,7 @@ from torch.nn import Module from torch.utils._pytree import tree_map from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert ( num_microbatches is not None or microbatch_size is not None ), "Either num_microbatches or microbatch_size should be provided" + self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None + self.last_batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None self._use_microbatch_size = num_microbatches is None + # P2PMeta cache + self.send_metadata_forward = True + self.send_metadata_backward = True + self.metadata_recv_forward = None + self.metadata_recv_backward = None + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): batch = next(data_iter) if device is not None: batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch self.batch_size = get_batch_size(batch) + if self.last_batch_size is None: + self.last_batch_size = self.batch_size + else: + assert self.forward_only or self.last_batch_size == self.batch_size + # TODO: support arbitrary batch size when forward_only=True self.microbatch_offset = 0 if not self._use_microbatch_size: assert ( @@ -92,12 +106,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Returns: Any: The input tensor or input tensor list. """ - if self.stage_manager.is_first_stage(): - input_tensor = None - else: - input_tensor = self.comm.recv_forward(prev_rank) + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) - return input_tensor + return input_tensor def recv_backward(self, next_rank: int = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. @@ -109,12 +123,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(): - output_tensor_grad = None - else: - output_tensor_grad = self.comm.recv_backward(next_rank) + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) - return output_tensor_grad + return output_tensor_grad def send_forward(self, output_object: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. @@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): next_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_object, next_rank) - - def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: - """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. - For 1F1B. - - Args: - output_object (Any): Object to be sent. - next_rank (int, optional): The rank of the recipient of the tensor. - """ - if not self.stage_manager.is_last_stage(): - return self.comm.send_forward_recv_backward(output_object, next_rank) + self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) + self.send_metadata_forward = False def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. @@ -147,7 +151,29 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_object, prev_rank) + self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) + self.send_metadata_backward = False + + def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: + """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.send_forward_recv_backward( + output_object, + next_rank, + send_metadata=self.send_metadata_forward, + metadata_recv=self.metadata_recv_backward, + ) + self.send_metadata_forward = False + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + + return output_tensor_grad def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. @@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): prev_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_first_stage(): - return self.comm.send_backward_recv_forward(output_object, prev_rank) + input_tensor = self.comm.send_backward_recv_forward( + output_object, + prev_rank, + send_metadata=self.send_metadata_backward, + metadata_recv=self.metadata_recv_forward, + ) + self.send_metadata_backward = False + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) - def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: - """Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline. - For 1F1B. - - Args: - input_object (Any): Object to be sent. - prev_rank (int, optional): The previous rank of the recipient of the tensor. - next_rank (int, optional): The next rank of the recipient of the tensor. - """ - if self.stage_manager.is_first_stage(): - return self.comm.send_forward(input_object, next_rank) - elif self.stage_manager.is_last_stage(): - return self.comm.recv_forward(prev_rank) - else: - return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank) + return input_tensor def forward_step( self, @@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Returns: dict: A dict with keys: 'loss' and 'outputs'. """ - forward_only = not torch.is_grad_enabled() + + self.forward_only = not torch.is_grad_enabled() if optimizer is None: - assert forward_only, "Optimizer should be passed when doing backward." + assert self.forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) @@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): input_objs = None output_objs = None - if not forward_only: + if not self.forward_only: input_objs = [] output_objs = [] - outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + accum_loss = None if return_loss and self.stage_manager.is_last_stage(): accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None # Run warmup forward passes. for i in range(num_warmup_microbatches): input_obj = self.recv_forward() - output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - self.send_forward(output_obj) - if not forward_only: + if not self.forward_only: input_objs.append(input_obj) output_objs.append(output_obj) @@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - if forward_only: + + if self.forward_only: self.send_forward(output_obj) if not last_iteration: input_obj = self.recv_forward() - else: - # TODO adjust here - self.send_forward(output_obj) - output_obj_grad = self.recv_backward() + else: + output_obj_grad = self.send_forward_recv_backward(output_obj) # Add input_obj and output_obj to end of list. input_objs.append(input_obj) output_objs.append(output_obj) @@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) if last_iteration: - input_obj = None + self.send_backward(input_obj_grad) else: - input_obj = self.recv_forward() - self.send_backward(input_obj_grad) + input_obj = self.send_backward_recv_forward(input_obj_grad) # Run cooldown backward passes. - if not forward_only: + if not self.forward_only: for i in range(num_warmup_microbatches): input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) @@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(input_obj_grad) + if not self.forward_only: + assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) + if outputs is not None: if isinstance(model, ModelWrapper): model = model.unwrap() diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index d7853938a..c8f904208 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,3 +1,4 @@ +import contextlib from typing import Dict, List, Optional, Tuple import torch.distributed as dist @@ -68,45 +69,39 @@ class PipelineStageManager: # for shardformer, hold model chunk id self.model_chunk_id: Optional[int] = None - def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: + def is_first_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the first stage. NOTE: 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device. - 2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device() + 2. invoke is_first_stage() with ignore_chunk=True is equivalent to invoke is_first_device() Returns: bool: Whether the current stage is the first stage. """ - if self.is_interleave and model_chunk_id is None: - model_chunk_id = self.model_chunk_id - assert self.is_interleave ^ ( - model_chunk_id is None - ), "model_chunk_id must be specified when using interleaved pipeline" - if not self.is_interleave or model_chunk_id < 0: + assert isinstance(ignore_chunk, bool) + assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) + if not self.is_interleave or ignore_chunk: return self.stage == 0 else: - return self.stage == 0 and model_chunk_id == 0 + return self.stage == 0 and self.model_chunk_id == 0 - def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: + def is_last_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the last stage. NOTE: 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device. - 2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device() + 2. invoke is_last_stage() with ignore_chunk=True is equivalent to invoke is_last_device() Returns: bool: Whether the current stage is the last stage. """ - if self.is_interleave and model_chunk_id is None: - model_chunk_id = self.model_chunk_id - assert self.is_interleave ^ ( - model_chunk_id is None - ), "model_chunk_id must be specified when using interleaved pipeline" - if not self.is_interleave or model_chunk_id < 0: + assert isinstance(ignore_chunk, bool) + assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) + if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1 + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: @@ -174,3 +169,10 @@ class PipelineStageManager: ProcessGroup: Process group of the given stages. """ return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) + + @contextlib.contextmanager + def switch_model_chunk_id(self, model_chunk_id: int): + old_model_chunk_id = self.model_chunk_id + self.model_chunk_id = model_chunk_id + yield + self.model_chunk_id = old_model_chunk_id diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 78363bf5e..0ab63b765 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -309,11 +309,11 @@ class BertPolicy(Policy): num_model_chunks=stage_manager.num_model_chunks, num_stages=stage_manager.num_stages, ) - if stage_manager.is_first_stage(-1): + if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(-1): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.pooler) else: @@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy): """Get pipeline layers for current stage""" held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers @@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers @@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers @@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.qa_outputs) return held_layers diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index eee2259f2..39a4d4023 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -8,7 +8,11 @@ from torch.nn import Module from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D -from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy +from ..modeling.llama import ( + LlamaPipelineForwards, + get_llama_flash_attention_forward, + get_lm_forward_with_dist_cross_entropy, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] @@ -140,21 +144,42 @@ class LlamaPolicy(Policy): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "LlamaModel": - module = self.model - else: - module = self.model.model + if self.pipeline_stage_manager is None: + return + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)} + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls ) - return + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -167,13 +192,32 @@ class LlamaPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) return held_layers @@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy): new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col - ) + SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col) ], - method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) } policy.update(new_item) @@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers @@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.score) return held_layers diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index b349d7edf..aad12c9c2 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -57,9 +57,7 @@ def evaluate_model( def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( - None if not booster.plugin.stage_manager.is_interleave else -1 - ) + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: @@ -136,9 +134,7 @@ def train_epoch( coordinator: DistCoordinator, ): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( - None if not booster.plugin.stage_manager.is_interleave else -1 - ) + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) total_step = len(train_dataloader) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index daf7d2fd4..a4c29b7c8 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -133,7 +133,9 @@ def main(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, + pp_style="interleaved", zero_stage=args.zero, + num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), num_microbatches=args.mbs, precision="bf16", diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 1665711ce..40b6ac8eb 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -1,47 +1,80 @@ +import warnings + import pytest import torch import torch.distributed as dist import colossalai from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device +WORLD_SIZE = 2 + def check_p2p_communication(): - pg_mesh = ProcessGroupMesh(2) + pg_mesh = ProcessGroupMesh(WORLD_SIZE) stage_manager = PipelineStageManager(pg_mesh, 0) p2p = PipelineP2PCommunication(stage_manager) rank = dist.get_rank() tensor = torch.ones(1, device=get_current_device()) + data = [ + "tensor", + tensor, + [tensor], + {"tensor": tensor}, + ] if rank == 0: - p2p.send_forward(tensor) - p2p.send_forward([tensor]) - p2p.send_forward({"tensor": tensor}) - else: - obj = p2p.recv_forward() - assert torch.equal(obj, tensor) - obj = p2p.recv_forward() - assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) - obj = p2p.recv_forward() - assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) + for obj in data: + p2p.send_forward(obj) + for i in range(len(data)): + recv_obj = p2p.send_forward_recv_backward(data[i]) + assert recv_obj == data[-(i + 1)] + elif rank == 1: + for obj in data: + recv_obj = p2p.recv_forward() + assert recv_obj == obj + for i in range(len(data)): + p2p.send_backward(data[-(i + 1)]) + recv_obj = p2p.recv_forward() + assert recv_obj == data[i] if rank == 1: - p2p.send_backward(tensor) - p2p.send_backward([tensor]) - p2p.send_backward({"tensor": tensor}) - else: - obj = p2p.recv_backward() - assert torch.equal(obj, tensor) - obj = p2p.recv_backward() - assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) - obj = p2p.recv_backward() - assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) + for obj in data: + p2p.send_backward(obj) + for i in range(len(data)): + recv_obj = p2p.send_backward_recv_forward(data[i]) + assert recv_obj == data[-(i + 1)] + elif rank == 0: + for obj in data: + recv_obj = p2p.recv_backward() + assert recv_obj == obj + for i in range(len(data)): + recv_obj = p2p.recv_backward() + p2p.send_forward(data[-(i + 1)]) + assert recv_obj == data[i] + + warnings.filterwarnings("error") + tensor_metadata = TensorMetadata( + key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad + ) + comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata) + if rank == 0: + recv_obj = p2p.send_forward_recv_backward( + tensor, + send_metadata=False, + metadata_recv=comm_metadata, + ) + assert recv_obj == tensor + elif rank == 1: + recv_obj = p2p.recv_forward(metadata_recv=comm_metadata) + assert recv_obj == tensor + p2p.send_backward(tensor, send_metadata=False) def run_dist(rank, world_size, port): @@ -52,7 +85,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_pipeline_p2p(): - spawn(run_dist, 2) + spawn(run_dist, WORLD_SIZE) if __name__ == "__main__": diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 4de50245f..0e81818eb 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -37,12 +37,13 @@ def pp_linear_fwd( stage_mgr: PipelineStageManager = None, model_chunk_id: int = None, ): - if stage_mgr.is_first_stage(model_chunk_id): - return {"input_obj": forward(data)} - elif stage_mgr.is_last_stage(model_chunk_id): - return forward(input_obj) - else: - return {"input_obj": forward(input_obj)} + with stage_mgr.switch_model_chunk_id(model_chunk_id): + if stage_mgr.is_first_stage(): + return {"input_obj": forward(data)} + elif stage_mgr.is_last_stage(): + return forward(input_obj) + else: + return {"input_obj": forward(input_obj)} def run_pp( @@ -107,7 +108,7 @@ def run_pp( ) # check loss - if stage_manager.is_last_stage(-1): + if stage_manager.is_last_stage(ignore_chunk=True): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients @@ -119,6 +120,7 @@ def run_pp( # step torch_optimizer.step() pp_optimizer.step() + pp_optimizer.zero_grad() # check updated param for i in range(num_model_chunk): @@ -126,6 +128,24 @@ def run_pp( assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) + if stage_manager.is_last_stage(ignore_chunk=True): + assert torch.allclose(torch_loss, pp_ret["loss"]) + + for layer in sharded_model: + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None + else: + assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + @pytest.mark.dist @pytest.mark.parametrize("num_microbatch", [4, 12]) diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 1d77edc2d..5f27be396 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -4,6 +4,7 @@ from types import MethodType import pytest import torch +import torch.distributed as dist import torch.nn as nn import colossalai @@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all +DIM = 8 +NUM_LAYER = 8 + class MlpModel(nn.Module): def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(4, 8) - self.linear2 = nn.Linear(8, 4) + super().__init__() + self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)]) def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) + for layer in self.layers: + x = layer(x) return x def pp_linear_fwd( - forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None + forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, ): if stage_mgr.is_first_stage(): return {"input_obj": forward(data)} @@ -38,34 +44,45 @@ def pp_linear_fwd( return {"input_obj": forward(input_obj)} -def examine_pp(): +def examine_pp(num_microbatch: int, batch_size: int): """ This test is to examine the correctness of 1F1B, compared with torch. Be aware it contains some hardcodes. """ - world_size = torch.distributed.get_world_size() - local_rank = torch.distributed.get_rank() + world_size = dist.get_world_size() + dist.get_rank() seed_all(1453) - NUM_MICRO_BATCHS = 4 - BATCH_SIZE = 4 - # create models torch_model = MlpModel().cuda() pp_model = copy.deepcopy(torch_model).cuda() - DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - pg_mesh = ProcessGroupMesh(1, world_size, 1) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS) + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0) + schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch) - for idx, (_, sub_model) in enumerate(pp_model.named_children()): - if idx % (world_size) == local_rank: - sharded_model = sub_model.cuda() + rank = dist.get_rank() + sharded_model = torch.nn.ModuleList() + num_local_layer = NUM_LAYER // world_size + for idx, sub_model in enumerate(pp_model.layers): + if idx // num_local_layer == rank: + sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_local_layer - sharded_model._forward = sharded_model.forward - sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward) + def custom_fwd(self, x): + for layer in self._modules.values(): + x = layer(x) + return x + + sharded_model._forward = MethodType(custom_fwd, sharded_model) + sharded_model.forward = MethodType( + partial( + pp_linear_fwd, + stage_mgr=stage_manager, + ), + sharded_model._forward, + ) # create optimizer torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -73,19 +90,15 @@ def examine_pp(): # create seed_all(1453) - if stage_manager.is_first_stage(): - input_list = [torch.rand(BATCH_SIZE, 4).cuda()] - else: - input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] - torch.distributed.all_reduce(input_list[0]) + input_list = [torch.rand(batch_size, DIM).cuda()] + dist.all_reduce(input_list[0]) - criterion = lambda x, y: torch.mean(x) + criterion = lambda x, *arg, **kwargs: (x * x).mean() # forward and backward torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output, _) + torch_loss = criterion(torch_output) torch_loss.backward() - pp_ret = schedule.forward_backward_step( sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True ) @@ -95,34 +108,66 @@ def examine_pp(): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients - torch_grad = [] - for torch_p in torch_model.parameters(): - torch_grad.append(torch_p.grad.data) - for idx, pp_p in enumerate(sharded_model.parameters()): - assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + for i in range(len(sharded_model)): + idx = rank * num_local_layer + i + assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() pp_optimizer.step() + pp_optimizer.zero_grad() # check updated param - torch_param = [] - for torch_p in torch_model.parameters(): - torch_param.append(torch_p.data) - for idx, pp_p in enumerate(sharded_model.parameters()): - assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + for i in range(len(sharded_model)): + idx = rank * num_local_layer + i + assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) + assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret["loss"]) + + for layer in sharded_model: + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None + else: + assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) -def run_dist(rank, world_size, port): +def run_dist( + rank: int, + world_size: int, + port: int, + num_microbatch: int, + batch_size: int, +): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - examine_pp() + examine_pp(num_microbatch, batch_size) @pytest.mark.dist +@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("batch_size", [12]) +@pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() -def test_pp(): - spawn(run_dist, 2) +def test_pp(num_microbatch: int, batch_size: int, world_size: int): + assert NUM_LAYER % world_size == 0 + spawn( + run_dist, + world_size, + num_microbatch=num_microbatch, + batch_size=batch_size, + ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=4, world_size=4) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 6acbe4ff5..87e661802 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -203,7 +203,7 @@ def check_output_hidden_state( ): org_hidden_state = org_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(): + if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] else: sharded_hidden_state = sharded_output.last_hidden_state @@ -229,6 +229,10 @@ def check_weight( org_weight = getattr_(org_model, suffix).weight sharded_weight = getattr_(sharded_model, suffix).weight + # skip if layer is not held by this process + if sharded_weight is None: + continue + if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index b38793b7c..768bd95bd 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"] col_layer_for_check = ["encoder.layer[0].output.dense"] row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: col_layer_grads = get_grad_tensors_for_check( bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) @@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): + if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: @@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): - check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) # check grads check_all_grad_tensors(grads_to_check) @@ -183,6 +184,17 @@ def run_bert_test(test_config): "zero_stage": 1, "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "interleaved", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, ], ) def run_bert_3d_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index f8f08e1d0..c7edcfb35 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-4 else: @@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): + if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: @@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if stage_manager is None or stage_manager.is_first_stage(): + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: @@ -179,6 +179,17 @@ def run_llama_test(test_config): "zero_stage": 1, "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "interleaved", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, ], ) def run_llama_3d_test(test_config):