[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as the p2p op list needs to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
Wenhao Chen 2023-12-22 10:44:00 +08:00 committed by GitHub
parent af952673f7
commit 4fa689fca1
15 changed files with 728 additions and 446 deletions
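
The plugin changes below sit behind two user-facing knobs, `pp_style` and `num_model_chunks`, which this commit routes through to the llama policy and the interleaved schedule. A minimal usage sketch, assuming a distributed context already launched via colossalai (the parallel sizes here are placeholders):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Placeholder sizes; pp_style and num_model_chunks are the options
# asserted on in the first hunk below.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="interleaved",  # "1f1b" or "interleaved"
    num_model_chunks=2,      # must be 1 when pp_style == "1f1b"
)
booster = Booster(plugin=plugin)
```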

colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -1,4 +1,5 @@
 import ctypes
+import os
 import random
 from contextlib import contextmanager
 from functools import partial
@@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
+            if os.getenv("NCCL_BUFFSIZE") is None:
+                logger = get_dist_logger()
+                logger.warning(
+                    "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
+                )
+                os.environ["NCCL_BUFFSIZE"] = "134217728"
+
             assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
             assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
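
The hunk above defaults NCCL's channel buffer to 128 MB when the user has not set one, since undersized buffers can make batched p2p sends hang. A sketch of the same fallback and of how to pick a larger value at launch time instead (the torchrun command line is illustrative):

```python
import os

# Mirror of the plugin's fallback: 128 MiB unless already set.
# 134217728 == 128 * 1024 * 1024.
if os.getenv("NCCL_BUFFSIZE") is None:
    os.environ["NCCL_BUFFSIZE"] = str(128 * 1024 * 1024)

# To override, export a bigger buffer before launching, e.g.:
#   NCCL_BUFFSIZE=268435456 torchrun --nproc_per_node=8 train.py
```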

colossalai/pipeline/p2p.py

@@ -4,13 +4,13 @@
 import io
 import pickle
 import re
-from typing import Any, List, Optional, Union
 from collections import namedtuple
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 import torch.distributed as dist
-from dataclasses import dataclass
-from enum import Enum
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
@@ -20,7 +20,7 @@ from .stage_manager import PipelineStageManager
 _unpickler = pickle.Unpickler
 
 
-def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
+def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     """transform tensor to object with unpickle.
     Info of the device in bytes stream will be modified into current device before unpickling
@@ -48,21 +48,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     return unpickle
 
 
-def check_for_nccl_backend(group):
-    pg = group or c10d._get_default_group()
-    # Gate PG wrapper check on Gloo availability.
-    if c10d._GLOO_AVAILABLE:
-        # It is not expected for PG to be wrapped many times, but support it just
-        # in case
-        while isinstance(pg, c10d._ProcessGroupWrapper):
-            pg = pg.wrapped_pg
-    return (
-        c10d.is_nccl_available() and
-        pg.name() == c10d.Backend.NCCL
-    )
-
-
+# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
 ):
@@ -70,13 +56,11 @@ def _broadcast_object_list(
     The only difference is that object will be move to correct device after unpickled.
     If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
     be updated with data sent from rank src.
-
     Args:
         object_list (List[Any]): list of object to broadcast
         src (int): source rank to broadcast
         dst (int): dst rank to broadcast
         device (:class:`torch.device`): device to do broadcast. current device in default
-
     """
 
     if c10d._rank_not_in_group(group):
@@ -149,6 +133,18 @@ def _broadcast_object_list(
                 object_list[i] = unpickle_object
 
 
+def check_for_nccl_backend(group):
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just
+        # in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg
+
+    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
+
+
 def check_device(group):
     is_nccl_backend = check_for_nccl_backend(group)
     current_device = None
@@ -159,14 +155,14 @@ def check_device(group):
     return current_device, is_nccl_backend
 
 
-TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
+TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
 
 
 class P2PDataType(Enum):
-    serialization = 0
-    tensor = 1
-    list = 2
-    dict = 3
+    Serialization = 0
+    Tensor = 1
+    List = 2
+    Dict = 3
 
 
 @dataclass
@@ -175,45 +171,71 @@ class P2PMetadata:
     content: Union[List[TensorMetadata], TensorMetadata, Any]
 
 
-def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
+def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
     if isinstance(obj, torch.Tensor):
         obj = obj.contiguous()
         op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
         ops_queue.append(op_to_add)
     else:
         for tensor_to_comm in obj:
-            tensor_to_comm = tensor_to_comm.contiguous()
-            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
-            ops_queue.append(op_to_add)
+            assert isinstance(tensor_to_comm, torch.Tensor)
+            filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
 
 
-def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
-    if p2p_metadata.data_type == P2PDataType.tensor:
+def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
+    if p2p_metadata.data_type == P2PDataType.Tensor:
         metadata = p2p_metadata.content
-        tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+        tensor_recv = torch.empty(
+            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
+        )
         return tensor_recv
-    elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
+    elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
         buffer_recv = []
         for metadata in p2p_metadata.content:
-            tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+            tensor_recv = torch.empty(
+                metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
+            )
             buffer_recv.append(tensor_recv)
         return buffer_recv
     else:
         raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
 
 
-def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
+def create_fast_send_metadata(object: Any) -> P2PMetadata:
+    assert _check_if_fast_send_available(object)
+    if isinstance(object, torch.Tensor):
+        data_type = P2PDataType.Tensor
+        content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
+    elif isinstance(object, list):
+        data_type = P2PDataType.List
+        content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
+    elif isinstance(object, dict):
+        data_type = P2PDataType.Dict
+        content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
+    else:
+        raise RuntimeError("Cannot handle object of type {}".format(type(object)))
+    return P2PMetadata(data_type, content)
+
+
+def _batch_send_recv_tensor(
+    send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
+    recv_tensor_metadata: Optional[P2PMetadata],
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device: Any,
+) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
     buffer_recv = None
-    if recv_tensor_metadata is not None:
+    if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
         buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
 
     ops = []
-    if send_dst is not None:
+    if send_dst is not None and send_tensor_list is not None:
+        assert send_group is not None
         filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
-    if recv_src is not None:
-        assert buffer_recv is not None
+    if recv_src is not None and buffer_recv is not None:
+        assert recv_group is not None
         filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
 
     if len(ops) > 0:
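
The fast-send path above replaces per-step pickling with a one-time description of the tensors being exchanged: `create_fast_send_metadata` records key/shape/dtype/requires_grad, and `create_recv_buffer` turns such a description back into empty receive buffers. A single-process sketch of the idea (`describe`/`allocate` are stand-ins, not the module's API):

```python
import torch

def describe(tensors: dict):
    # One (key, shape, dtype, requires_grad) record per tensor, like TensorMetadata.
    return [(k, v.shape, v.dtype, v.requires_grad) for k, v in tensors.items()]

def allocate(description, device="cpu"):
    # Pre-allocate receive buffers from the cached description.
    return {
        k: torch.empty(shape, dtype=dtype, requires_grad=requires_grad, device=device)
        for k, shape, dtype, requires_grad in description
    }

to_send = {"hidden_states": torch.randn(2, 4), "attention_mask": torch.ones(2, 4)}
buffers = allocate(describe(to_send))
assert buffers["hidden_states"].shape == to_send["hidden_states"].shape
```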
@@ -221,24 +243,26 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
         for req in reqs:
             req.wait()
-        torch.cuda.synchronize()
 
     # Remove synchronization according to Pytorch's documentation
     # However, the Megatron-LM does synchronization here
     # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
     # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
-    # torch.cuda.synchronize()
+    torch.cuda.synchronize()
 
     return buffer_recv
 
 
 def _send_recv_serialization_object(
     object: Any,
-    send_dst: Optional[int], recv_src: Optional[int],
-    send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
-    current_device,
-    is_nccl_backend):
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device: Any,
+    is_nccl_backend: bool,
+) -> Optional[P2PMetadata]:
     ops = []
     send_object_tensor = None
     if object is not None and send_dst is not None:
         if Version(torch.__version__) >= Version("1.13.0"):
@@ -264,10 +288,8 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()
-        torch.cuda.synchronize()
 
     # See the comment in `_batch_send_recv_tensor`
-    # torch.cuda.synchronize()
+    torch.cuda.synchronize()
 
     ops = []
@@ -286,52 +308,77 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()
-        torch.cuda.synchronize()
 
     # See the comment in `_batch_send_recv_tensor`
-    # torch.cuda.synchronize()
+    torch.cuda.synchronize()
 
     if recv_object_tensor is not None and recv_object_size_tensor is not None:
         recv_object_tensor = recv_object_tensor.type(torch.uint8)
         if recv_object_tensor.device != torch.device("cpu"):
             recv_object_tensor = recv_object_tensor.cpu()
 
-        unpickle_object = _cuda_safe_tensor_to_object(
-            recv_object_tensor, recv_object_size_tensor.item())
+        unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
 
-        if (
-            isinstance(unpickle_object, torch.Tensor)
-            and unpickle_object.device.index != torch.cuda.current_device()
-        ):
+        if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
             unpickle_object = unpickle_object.cuda()
 
         return unpickle_object
 
 
-def _check_if_fast_send_available(object):
-    if type(object) is torch.Tensor:
+def _check_if_fast_send_available(object: Any) -> bool:
+    if isinstance(object, torch.Tensor):
         return True
-    elif type(object) is list:
-        is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
+    elif isinstance(object, list):
+        is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
         return is_list_of_tensor
-    elif type(object) is dict:
-        is_dict_of_tensor = all([type(k) is str and type(
-            v) is torch.Tensor for k, v in object.items()])
+    elif isinstance(object, dict):
+        is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
         return is_dict_of_tensor
     return False
 
 
 def _communicate(
-    object,
+    object: Any,
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup] = None,
     recv_group: Optional[ProcessGroup] = None,
+    send_metadata: bool = True,
+    metadata_recv: Optional[P2PMetadata] = None,
 ) -> Any:
-    if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
-        c10d._warn_not_in_group("_communicate")
-        return
+    """
+    Send and receive object from send_dst and recv_src respectively
+
+    Args:
+        object (Any): object needed to be sent
+        send_dst (int): rank of the destination
+        recv_src (int): rank of the source
+        send_group (ProcessGroup, optional): process group of sender
+        recv_group (ProcessGroup, optional): process group of receiver
+        send_metadata (bool, optional): whether to send metadata
+        metadata_recv (P2PMetadata, optional): metadata of the object to be received
+    """
+    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
+    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
+    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
+    send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
+    assert (
+        metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
+    ), "metadata_recv type must not be Serialization"
+
+    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
+    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
+    if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
+        _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+        return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+
+    # NOTE: only the following 5 cases are valid:
+    # 1. send() [needs extra metadata] and no recv()
+    # 2. recv() [needs extra metadata] and no send()
+    # 3. neither send() nor recv() need extra metadata
+    assert not (send_dst is not None and send_metadata) or recv_src is None
+    assert not (recv_src is not None and metadata_recv is None) or send_dst is None
+    assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
+    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
+
     current_send_device, is_send_nccl_backend = check_device(send_group)
     current_recv_device, is_recv_nccl_backend = check_device(recv_group)
@@ -341,67 +388,56 @@ def _communicate(
         assert current_send_device == current_recv_device
     current_device = current_send_device
 
-    assert (send_dst is not None) or (recv_src is not None)
-
-    can_fast_send = False
-    send_metadata = None
-    if send_dst is not None:
-        can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
-        if not can_fast_send:
-            send_metadata = P2PMetadata(P2PDataType.serialization, object)
-        else:
-            if type(object) is torch.Tensor:
-                data_type = P2PDataType.tensor
-                content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
-            elif type(object) is list:
-                data_type = P2PDataType.list
-                content = []
-                for v in object:
-                    content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
-            elif type(object) is dict:
-                data_type = P2PDataType.dict
-                content = []
-                for k, v in object.items():
-                    content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
-            else:
-                raise ValueError('Cannot send object of type {}'.format(type(object)))
-            send_metadata = P2PMetadata(data_type, content)
-
-    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
-    if recv_metadata is not None:
-        assert type(recv_metadata) is P2PMetadata
-        if recv_metadata.data_type == P2PDataType.serialization:
-            return recv_metadata.content
-    if not can_fast_send and send_dst is not None:
-        return
+    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
+        metadata_send = None
+        if send_dst is not None and send_metadata:
+            can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
+            if not can_fast_send:
+                metadata_send = P2PMetadata(P2PDataType.Serialization, object)
+            else:
+                metadata_send = create_fast_send_metadata(object)
+
+        # Send and receive metadata
+        _metadata_recv = _send_recv_serialization_object(
+            object=metadata_send,
+            send_dst=send_dst if send_metadata else None,
+            recv_src=recv_src if metadata_recv is None else None,
+            send_group=send_group if send_metadata else None,
+            recv_group=recv_group if metadata_recv is None else None,
+            current_device=current_device,
+            is_nccl_backend=is_nccl_backend,
+        )
+        assert metadata_recv is None or _metadata_recv is None
+        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
 
     send_tensor_list = None
-    if type(object) is torch.Tensor:
+    if isinstance(object, torch.Tensor):
         send_tensor_list = object
-    elif type(object) is list:
+    elif isinstance(object, list):
         send_tensor_list = object
-    elif type(object) is dict:
+    elif isinstance(object, dict):
         send_tensor_list = list(object.values())
 
-    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
-
-    if recv_metadata is not None:
-        assert recv_buffer is not None
-        if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
-            return recv_buffer
-        elif recv_metadata.data_type == P2PDataType.dict:
-            return {
-                k: v
-                for k, v in zip(
-                    [m.key for m in recv_metadata.content],
-                    recv_buffer,
-                )
-            }
-        else:
-            raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
+    # Send and receive data
+    recv_buffer = _batch_send_recv_tensor(
+        send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
+    )
+
+    if metadata_recv is not None:
+        assert isinstance(metadata_recv, P2PMetadata)
+        if metadata_recv.data_type == P2PDataType.Serialization:
+            return metadata_recv.content
+        else:
+            assert recv_buffer is not None
+            if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
+                return recv_buffer
+            elif metadata_recv.data_type == P2PDataType.Dict:
+                return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
+            else:
+                raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
 
 
-def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
+def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
     """send anything to dst rank
 
     Args:
@@ -411,10 +447,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
     Returns:
         None
     """
-    _communicate(object, send_dst=dst, recv_src=None, send_group=group)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
 
 
-def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
+def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
     """recv anything from src
 
     Args:
@@ -423,7 +459,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     Returns:
         Any: Object received from src.
     """
-    return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
 
 
 def _p2p_comm(
@@ -436,7 +472,7 @@ def _p2p_comm(
     """
     Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
 
-    Agrs:
+    Args:
         tensor_send_next (torch.Tensor): tensor to be sent to next stage
         recv_prev (bool): whether to receive tensor from previous stage
         peer (int): rank of the peer
@@ -467,7 +503,6 @@ def _p2p_comm(
                 group=group,
             )
             ops.append(recv_prev_op)
-
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
@@ -490,7 +525,6 @@ def _p2p_comm(
                 group=group,
             )
             ops.append(send_next_op)
-
         if tensor_recv_prev is not None:
             recv_prev_op = dist.P2POp(
                 dist.irecv,
@@ -510,7 +544,7 @@ class PipelineP2PCommunication:
     def __init__(self, stage_manager: PipelineStageManager) -> None:
         self.stage_manager = stage_manager
 
-    def recv_forward(self, prev_rank: int = None) -> Any:
+    def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
 
         Args:
@@ -522,11 +556,13 @@ class PipelineP2PCommunication:
         if prev_rank is None:
             prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
-        input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
+        input_tensor = _recv_object(
+            prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
+        )
 
         return input_tensor
 
-    def recv_backward(self, next_rank: int = None) -> Any:
+    def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
 
         Args:
@@ -539,12 +575,12 @@ class PipelineP2PCommunication:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         output_tensor_grad = _recv_object(
-            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
+            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
         )
 
         return output_tensor_grad
 
-    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
         """Sends the input tensor to the next stage in pipeline.
 
         Args:
@@ -554,9 +590,15 @@ class PipelineP2PCommunication:
         if next_rank is None:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
+        _send_object(
+            output_object,
+            cur_rank,
+            next_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+            send_metadata,
+        )
 
-    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
 
         Args:
@@ -566,9 +608,21 @@ class PipelineP2PCommunication:
         if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
-        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+        _send_object(
+            input_object,
+            cur_rank,
+            prev_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
+            send_metadata,
+        )
 
-    def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
+    def send_forward_recv_backward(
+        self,
+        input_object: Any,
+        next_rank: Optional[int] = None,
+        send_metadata: bool = True,
+        metadata_recv: Optional[P2PMetadata] = None,
+    ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
 
         Args:
@@ -581,11 +635,22 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
         return _communicate(
-            input_object, next_rank, next_rank,
-            send_group=group, recv_group=group,
+            input_object,
+            next_rank,
+            next_rank,
+            send_group=group,
+            recv_group=group,
+            send_metadata=send_metadata,
+            metadata_recv=metadata_recv,
         )
 
-    def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
+    def send_backward_recv_forward(
+        self,
+        input_object: Any,
+        prev_rank: Optional[int] = None,
+        send_metadata: bool = True,
+        metadata_recv: Optional[P2PMetadata] = None,
+    ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
 
         Args:
@@ -597,37 +662,22 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
         return _communicate(
-            input_object, prev_rank, prev_rank,
-            send_group=group, recv_group=group,
-        )
-
-    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
-        """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
-
-        Args:
-            input_object (Any): Object to be sent.
-            prev_rank (int, optional): The rank of the sender of the tensor
-            next_rank (int, optional): The rank of the recipient of the tensor
-        """
-        if prev_rank is None:
-            prev_rank = self.stage_manager.get_prev_rank()
-        if next_rank is None:
-            next_rank = self.stage_manager.get_next_rank()
-        cur_rank = self.stage_manager.get_rank()
-        recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
-        send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
-        return _communicate(
-            input_object,
-            send_dst=next_rank,
-            recv_src=prev_rank,
-            send_group=send_group,
-            recv_group=recv_group,
+            input_object,
+            prev_rank,
+            prev_rank,
+            send_group=group,
+            recv_group=group,
+            send_metadata=send_metadata,
+            metadata_recv=metadata_recv,
         )
 
     def p2p_communicate(
-        self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
+        self,
+        output_object: Any,
+        recv_pre: bool,
+        next_rank: Optional[int] = None,
+        comm_dtype: torch.dtype = torch.float16,
     ) -> None:
         """
         Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -636,10 +686,14 @@ class PipelineP2PCommunication:
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if peer is None:
-            peer = self.stage_manager.get_next_rank()
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         recv_tensor = _p2p_comm(
-            output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
+            output_object,
+            recv_pre,
+            next_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+            comm_dtype,
        )
         return recv_tensor
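
The `_communicate` rewrite above enforces the rule named in the commit message: every `dist.batch_isend_irecv` op list must target a single process group, and a combined send+recv that still needs a metadata exchange is first split into a pure send followed by a pure recv. A two-rank sketch of the batched pattern (gloo so it runs on CPU, assuming a PyTorch recent enough to allow gloo in `batch_isend_irecv`; the plugin itself targets NCCL; run with `torchrun --nproc_per_node=2 sketch.py`):

```python
import torch
import torch.distributed as dist

dist.init_process_group("gloo")
rank = dist.get_rank()
peer = 1 - rank  # the other of the two ranks

payload = torch.full((4,), float(rank))
buffer = torch.empty(4)

# One send and one recv posted together, all against the default group:
# this pairing is what keeps the exchange atomic.
ops = [
    dist.P2POp(dist.isend, payload, peer),
    dist.P2POp(dist.irecv, buffer, peer),
]
for req in dist.batch_isend_irecv(ops):
    req.wait()

assert buffer[0].item() == float(peer)
dist.destroy_process_group()
```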

colossalai/pipeline/schedule/interleaved_pp.py

@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule):
         assert (
             num_microbatch is not None or microbatch_size is not None
         ), "Either num_microbatch or microbatch_size should be provided"
+
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatch = num_microbatch
         self.microbatch_size = microbatch_size
@@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule):
 
         self.batch: Any
         self.batch_size: int
+        self.last_batch_size: Optional[int] = None
         self.microbatch_offset: List[int]
 
+        # P2PMeta cache
+        self.send_metadata_forward = True
+        self.send_metadata_backward = True
+        self.metadata_recv_forward = None
+        self.metadata_recv_backward = None
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
 
@@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule):
             batch = tree_map(partial(to_device, device=device), batch)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
+        if self.last_batch_size is None:
+            self.last_batch_size = self.batch_size
+        else:
+            assert self.forward_only or self.last_batch_size == self.batch_size
+            # TODO: support arbitrary batch size when forward_only=True
         self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
         if self.num_microbatch is not None:
             assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
@@ -106,10 +119,11 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input tensor or input tensor list.
         """
-        if self.stage_manager.is_first_stage(model_chunk_id):
-            input_tensor = None
-        else:
-            input_tensor = self.comm.recv_forward(prev_rank)
-
-        return input_tensor
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
+                if self.metadata_recv_forward is None:
+                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+
+                return input_tensor
@@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input gradient tensor or gradient tensor list.
         """
-        if self.stage_manager.is_last_stage(model_chunk_id):
-            output_tensor_grad = None
-        else:
-            output_tensor_grad = self.comm.recv_backward(next_rank)
-
-        return output_tensor_grad
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
+                if self.metadata_recv_backward is None:
+                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+
+                return output_tensor_grad
 
-    def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
         """Sends the input tensor to the next stage in pipeline.
         For interleaved 1F1B.
 
@@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule):
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if not self.stage_manager.is_last_stage(model_chunk_id):
-            self.comm.send_forward(output_object, next_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
+                self.send_metadata_forward = False
 
-    def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
         For interleaved 1F1B.
 
@@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule):
             input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor
         """
-        if not self.stage_manager.is_first_stage(model_chunk_id):
-            self.comm.send_backward(input_object, prev_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
+                self.send_metadata_backward = False
+
+    def send_forward_recv_backward(
+        self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
+    ) -> Any:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                output_tensor_grad = self.comm.send_forward_recv_backward(
+                    output_object,
+                    next_rank,
+                    send_metadata=self.send_metadata_forward,
+                    metadata_recv=self.metadata_recv_backward,
+                )
+                self.send_metadata_forward = False
+                if self.metadata_recv_backward is None:
+                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+
+                return output_tensor_grad
+
+    def send_backward_recv_forward(
+        self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
+    ) -> Any:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                input_tensor = self.comm.send_backward_recv_forward(
+                    output_object,
+                    prev_rank,
+                    send_metadata=self.send_metadata_backward,
+                    metadata_recv=self.metadata_recv_forward,
+                )
+                self.send_metadata_backward = False
+                if self.metadata_recv_forward is None:
+                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+
+                return input_tensor
 
     def forward_step(
         self,
@@ -180,7 +233,7 @@ class InterleavedSchedule(PipelineSchedule):
         # for the first stage, input_obj is None
         # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
-        self.stage_manager.model_chunk_id = model_chunk_id
-        if isinstance(model_chunk, ModuleList):
-            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
-        else:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if isinstance(model_chunk, ModuleList):
+                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+            else:
@@ -188,9 +241,8 @@ class InterleavedSchedule(PipelineSchedule):
-            internal_inputs = {} if input_obj is None else input_obj
-            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
-        self.stage_manager.model_chunk_id = None
+                internal_inputs = {} if input_obj is None else input_obj
+                internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+                output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
 
-        if self.stage_manager.is_last_stage(model_chunk_id):
+        if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
@@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             dict: A dict with keys: 'loss' and 'outputs'.
         """
-        # TODO: handle arbitrary batch size when forward_only == True
-        forward_only = not torch.is_grad_enabled()
+        self.forward_only = not torch.is_grad_enabled()
         if optimizer is None:
-            assert forward_only, "Optimizer should be passed when doing backward."
+            assert self.forward_only, "Optimizer should be passed when doing backward."
 
         self.load_batch(data_iter)
 
         num_microbatch = self.num_microbatch * self.num_model_chunks
-        if forward_only:
+        if self.forward_only:
             num_warmup_microbatch = num_microbatch
         else:
             num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
@@ -288,42 +339,28 @@ class InterleavedSchedule(PipelineSchedule):
         input_objs = None
         output_objs = None
 
-        if not forward_only:
+        if not self.forward_only:
             input_objs = [[] for _ in range(self.num_model_chunks)]
             output_objs = [[] for _ in range(self.num_model_chunks)]
 
-        outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
-        if return_loss and self.stage_manager.is_last_stage(-1):
-            accum_loss = torch.zeros(1, device=get_current_device())
-        else:
-            accum_loss = None
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
 
-        # for ranks except the first one, get into recv state
-        input_obj = self.recv_forward(0)
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
+            accum_loss = torch.zeros(1, device=get_current_device())
 
         # Run warmup forward passes.
         for i in range(num_warmup_microbatch):
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
-
-            # recv first on first rank to avoid sending or receiving at the same time
-            if self.stage_manager.is_first_stage(-1):
-                input_obj = self.recv_forward(model_chunk_id)
-                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-                self.send_forward(model_chunk_id, output_obj)
-                if not forward_only:
-                    input_objs[model_chunk_id].append(input_obj)
-                    output_objs[model_chunk_id].append(output_obj)
-            else:
-                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-                if not forward_only:
-                    input_objs[model_chunk_id].append(input_obj)
-                    output_objs[model_chunk_id].append(output_obj)
-                self.send_forward(model_chunk_id, output_obj)
-                if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
-                    break
-                model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
-                input_obj = self.recv_forward(model_chunk_id)
+            input_obj = self.recv_forward(model_chunk_id)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+            if not self.forward_only:
+                input_objs[model_chunk_id].append(input_obj)
+                output_objs[model_chunk_id].append(output_obj)
+            self.send_forward(model_chunk_id, output_obj)
+
+        if num_microbatch_remaining > 0:
+            model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
+            input_obj = self.recv_forward(model_chunk_id)
 
         # Run 1F1B in steady state.
@@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule):
             last_iteration = i == num_microbatch_remaining - 1
 
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            if forward_only:
-                self.send_forward(model_chunk_id, output_obj)
+            if self.forward_only:
                 if not last_iteration:
-                    input_obj = self.recv_forward(model_chunk_id)
+                    input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
+                else:
+                    self.send_forward(model_chunk_id, output_obj)
             else:
                 self.send_forward(model_chunk_id, output_obj)
@@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule):
 
                 # backward
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-                self.send_backward(model_chunk_id, input_obj_grad)
 
-                if last_iteration:
-                    input_obj = None
-                else:
+                if not last_iteration:
                     model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
                     input_obj = self.recv_forward(model_chunk_id)
 
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                self.send_backward(model_chunk_id, input_obj_grad)
+
         # Run cooldown backward passes.
-        if not forward_only:
+        if not self.forward_only:
             for i in range(num_microbatch_remaining, num_microbatch):
                 model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
                 input_obj = input_objs[model_chunk_id].pop(0)
@@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                 self.send_backward(model_chunk_id, input_obj_grad)
 
-        if not forward_only:
+        if not self.forward_only:
             assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
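
Both schedules now follow the same caching discipline: metadata accompanies only the first microbatch (`send_metadata_*` starts True and flips to False after the first send), and the first metadata received is kept so every later microbatch can pre-allocate buffers directly. This is sound because microbatch shapes are static once `last_batch_size` is pinned. A condensed sketch of the pattern (`comm` stands in for `PipelineP2PCommunication`):

```python
from colossalai.pipeline.p2p import create_fast_send_metadata

class ForwardMetaCache:
    """Sketch of the send/recv halves of the schedules' P2PMeta cache."""

    def __init__(self, comm):
        self.comm = comm
        self.send_metadata_forward = True
        self.metadata_recv_forward = None

    def send_forward(self, obj):
        self.comm.send_forward(obj, send_metadata=self.send_metadata_forward)
        self.send_metadata_forward = False  # shapes repeat; skip metadata from now on

    def recv_forward(self):
        obj = self.comm.recv_forward(metadata_recv=self.metadata_recv_forward)
        if self.metadata_recv_forward is None:
            self.metadata_recv_forward = create_fast_send_metadata(obj)
        return obj
```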

View File

@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device from colossalai.utils.device import get_current_device
@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert ( assert (
num_microbatches is not None or microbatch_size is not None num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided" ), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager) self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator. """Load a batch from data iterator.
@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter) batch = next(data_iter)
if device is not None: if device is not None:
batch = tree_map(partial(to_device, device=device), batch) batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch self.batch = batch
self.batch_size = get_batch_size(batch) self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = 0 self.microbatch_offset = 0
if not self._use_microbatch_size: if not self._use_microbatch_size:
assert ( assert (
@ -92,10 +106,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
Any: The input tensor or input tensor list. Any: The input tensor or input tensor list.
""" """
if self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
input_tensor = None input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
else: if self.metadata_recv_forward is None:
input_tensor = self.comm.recv_forward(prev_rank) self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor return input_tensor
@ -109,10 +123,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
Any: The input gradient tensor or gradient tensor list. Any: The input gradient tensor or gradient tensor list.
""" """
if self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
output_tensor_grad = None output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
else: if self.metadata_recv_backward is None:
output_tensor_grad = self.comm.recv_backward(next_rank) self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad return output_tensor_grad
@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank) self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None: def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline. """Sends the gradient tensor to the previous stage in pipeline.
@ -147,7 +151,29 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor prev_rank (int, optional): The rank of the recipient of the tensor
""" """
if not self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank) self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.send_forward_recv_backward(
output_object,
next_rank,
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor. prev_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank) input_tensor = self.comm.send_backward_recv_forward(
output_object,
prev_rank,
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: return input_tensor
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
def forward_step( def forward_step(
self, self,
@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
dict: A dict with keys: 'loss' and 'outputs'. dict: A dict with keys: 'loss' and 'outputs'.
""" """
forward_only = not torch.is_grad_enabled()
self.forward_only = not torch.is_grad_enabled()
if optimizer is None: if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward." assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter) self.load_batch(data_iter)
@@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        input_objs = None
        output_objs = None

-        if not forward_only:
+        if not self.forward_only:
            input_objs = []
            output_objs = []

-        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+        accum_loss = None
        if return_loss and self.stage_manager.is_last_stage():
            accum_loss = torch.zeros(1, device=get_current_device())
-        else:
-            accum_loss = None
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            input_obj = self.recv_forward()
            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
            self.send_forward(output_obj)

-            if not forward_only:
+            if not self.forward_only:
                input_objs.append(input_obj)
                output_objs.append(output_obj)
@@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            last_iteration = i == (num_microbatches_remaining - 1)

            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
-            if forward_only:
+            if self.forward_only:
                self.send_forward(output_obj)

                if not last_iteration:
                    input_obj = self.recv_forward()
            else:
-                # TODO adjust here
-                self.send_forward(output_obj)
-                output_obj_grad = self.recv_backward()
+                output_obj_grad = self.send_forward_recv_backward(output_obj)

                # Add input_obj and output_obj to end of list.
                input_objs.append(input_obj)
                output_objs.append(output_obj)
@@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

                if last_iteration:
-                    input_obj = None
-                else:
-                    input_obj = self.recv_forward()
-                self.send_backward(input_obj_grad)
+                    self.send_backward(input_obj_grad)
+                else:
+                    input_obj = self.send_backward_recv_forward(input_obj_grad)

        # Run cooldown backward passes.
-        if not forward_only:
+        if not self.forward_only:
            for i in range(num_warmup_microbatches):
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)
@@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                self.send_backward(input_obj_grad)

+        if not self.forward_only:
+            assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+
        if outputs is not None:
            if isinstance(model, ModelWrapper):
                model = model.unwrap()
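The send_metadata_*/metadata_recv_* bookkeeping above is the 1F1B metadata cache: the first exchange in each direction ships a description of the payload (shapes, dtypes, tree structure), and every later exchange reuses it so only raw tensor buffers travel on the wire. A minimal standalone sketch of that idea; CachedMeta and CachingChannel are illustrative stand-ins, not the real PipelineP2PCommunication internals:

# Minimal sketch of the metadata-cache idea (illustrative names only).
from dataclasses import dataclass

import torch


@dataclass
class CachedMeta:
    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool


class CachingChannel:
    def __init__(self):
        self.sent_meta = False  # plays the role of send_metadata_forward/backward
        self.recv_meta = None   # plays the role of metadata_recv_forward/backward

    def send(self, tensor, send_fn, meta_send_fn):
        if not self.sent_meta:
            # the first send ships a description of the payload ...
            meta_send_fn(CachedMeta(tensor.shape, tensor.dtype, tensor.requires_grad))
            self.sent_meta = True
        send_fn(tensor)  # ... later sends move only the raw buffer

    def recv(self, recv_fn, meta_recv_fn):
        if self.recv_meta is None:
            self.recv_meta = meta_recv_fn()  # the first recv learns shape/dtype once
        buf = torch.empty(self.recv_meta.shape, dtype=self.recv_meta.dtype)
        return recv_fn(buf)  # receive straight into a preallocated buffer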

View File

@@ -1,3 +1,4 @@
+import contextlib
from typing import Dict, List, Optional, Tuple

import torch.distributed as dist
@@ -68,45 +69,39 @@ class PipelineStageManager:
        # for shardformer, hold model chunk id
        self.model_chunk_id: Optional[int] = None

-    def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
+    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        """Is the current stage the first stage.

        NOTE:
            1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
-            2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device()
+            2. invoking is_first_stage() with ignore_chunk=True is equivalent to invoking is_first_device()

        Returns:
            bool: Whether the current stage is the first stage.
        """
-        if self.is_interleave and model_chunk_id is None:
-            model_chunk_id = self.model_chunk_id
-        assert self.is_interleave ^ (
-            model_chunk_id is None
-        ), "model_chunk_id must be specified when using interleaved pipeline"
-        if not self.is_interleave or model_chunk_id < 0:
+        assert isinstance(ignore_chunk, bool)
+        assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
+        if not self.is_interleave or ignore_chunk:
            return self.stage == 0
        else:
-            return self.stage == 0 and model_chunk_id == 0
+            return self.stage == 0 and self.model_chunk_id == 0

-    def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
+    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        """Is the current stage the last stage.

        NOTE:
            1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
-            2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device()
+            2. invoking is_last_stage() with ignore_chunk=True is equivalent to invoking is_last_device()

        Returns:
            bool: Whether the current stage is the last stage.
        """
-        if self.is_interleave and model_chunk_id is None:
-            model_chunk_id = self.model_chunk_id
-        assert self.is_interleave ^ (
-            model_chunk_id is None
-        ), "model_chunk_id must be specified when using interleaved pipeline"
-        if not self.is_interleave or model_chunk_id < 0:
+        assert isinstance(ignore_chunk, bool)
+        assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
+        if not self.is_interleave or ignore_chunk:
            return self.stage == self.num_stages - 1
        else:
-            return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1
+            return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
    @property
    def num_stages(self) -> int:

@@ -174,3 +169,10 @@ class PipelineStageManager:
            ProcessGroup: Process group of the given stages.
        """
        return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
+
+    @contextlib.contextmanager
+    def switch_model_chunk_id(self, model_chunk_id: int):
+        old_model_chunk_id = self.model_chunk_id
+        self.model_chunk_id = model_chunk_id
+        yield
+        self.model_chunk_id = old_model_chunk_id
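switch_model_chunk_id gives schedules a scoped way to pin the active chunk, so is_first_stage()/is_last_stage() no longer need a chunk argument. A usage sketch, assuming an already-initialized stage_manager on an interleaved pipeline; `output` and `criterion` stand in for real schedule state:

# Usage sketch (illustrative variables, not actual schedule code).
with stage_manager.switch_model_chunk_id(model_chunk_id=1):
    # inside the block, stage checks answer for chunk 1 without extra arguments
    if stage_manager.is_last_stage():
        loss = criterion(output)
# on exit the previous model_chunk_id is restored, even if it was None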

View File

@@ -309,11 +309,11 @@ class BertPolicy(Policy):
                num_model_chunks=stage_manager.num_model_chunks,
                num_stages=stage_manager.num_stages,
            )
-            if stage_manager.is_first_stage(-1):
+            if stage_manager.is_first_stage(ignore_chunk=True):
                held_layers.append(module.embeddings)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.encoder.layer[start_idx:end_idx])
-            if stage_manager.is_last_stage(-1):
+            if stage_manager.is_last_stage(ignore_chunk=True):
                held_layers.append(module.pooler)
        else:

@@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
        """Get pipeline layers for current stage"""
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.cls)
        return held_layers

@@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.cls)
        return held_layers

@@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.cls)
        return held_layers

@@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.dropout)
            held_layers.append(self.model.classifier)
        return held_layers

@@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.dropout)
            held_layers.append(self.model.classifier)
        return held_layers

@@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.cls)
        return held_layers

@@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.dropout)
            held_layers.append(self.model.classifier)
        return held_layers

@@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
        """
        held_layers = super().get_held_layers()
        stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.qa_outputs)
        return held_layers

View File

@@ -8,7 +8,11 @@ from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D

-from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy
+from ..modeling.llama import (
+    LlamaPipelineForwards,
+    get_llama_flash_attention_forward,
+    get_lm_forward_with_dist_cross_entropy,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
@@ -140,21 +144,42 @@ class LlamaPolicy(Policy):
    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replace the original forward method of huggingface
        with a customized forward method, and add this change to the policy."""
-        if self.pipeline_stage_manager:
+        if self.pipeline_stage_manager is None:
+            return
+
        stage_manager = self.pipeline_stage_manager
        if self.model.__class__.__name__ == "LlamaModel":
            module = self.model
        else:
            module = self.model.model

+        if stage_manager.is_interleave:
+            layers_per_stage = self.distribute_layers(
+                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
+            )
+            stage_manager.stage_indices = Policy.get_stage_index(
+                layers_per_stage,
+                stage_manager.stage,
+                num_model_chunks=stage_manager.num_model_chunks,
+                num_stages=stage_manager.num_stages,
+            )
+            method_replacement = {
+                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+            }
+        else:
            layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)}
-            self.append_or_create_method_replacement(
-                description=method_replacement, policy=policy, target_key=model_cls
-            )
+            method_replacement = {
+                "forward": partial(
+                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                )
+            }

-        return
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""

@@ -167,6 +192,25 @@ class LlamaPolicy(Policy):
        stage_manager = self.pipeline_stage_manager
        held_layers = []

+        if stage_manager.is_interleave:
+            assert stage_manager.num_model_chunks is not None
+            layers_per_stage = self.distribute_layers(
+                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
+            )
+            stage_indices = Policy.get_stage_index(
+                layers_per_stage,
+                stage_manager.stage,
+                num_model_chunks=stage_manager.num_model_chunks,
+                num_stages=stage_manager.num_stages,
+            )
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.embed_tokens)
+            for start_idx, end_idx in stage_indices:
+                held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage(ignore_chunk=True):
+                held_layers.append(module.norm)
+        else:
            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
            if stage_manager.is_first_stage():
                held_layers.append(module.embed_tokens)
@@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
            new_item = {
                LlamaForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=Linear1D_Col
-                        )
+                        SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
                    ],
-                    method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                    method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                )
            }
            policy.update(new_item)

@@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers

@@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.score)
        return held_layers
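For orientation, the interleaved branch splits the decoder layers over num_stages * num_model_chunks virtual stages, and a physical stage holds virtual stages stage, stage + num_stages, and so on. A self-contained sketch that mirrors, but is not, the actual Policy helpers:

# Standalone sketch of the interleaved layer split (illustrative helpers).
# With 8 layers, num_stages=2, num_model_chunks=2, each device holds two ranges:
#   device 0 -> layers [0, 2) and [4, 6);  device 1 -> layers [2, 4) and [6, 8).

def distribute_layers(num_layers: int, num_parts: int) -> list:
    base, rem = divmod(num_layers, num_parts)
    return [base + (1 if i < rem else 0) for i in range(num_parts)]

def stage_ranges(num_layers: int, num_stages: int, num_model_chunks: int, stage: int):
    parts = distribute_layers(num_layers, num_stages * num_model_chunks)
    starts = [sum(parts[:i]) for i in range(len(parts))]
    # chunk k of this stage corresponds to virtual stage (stage + k * num_stages)
    return [
        (starts[stage + k * num_stages], starts[stage + k * num_stages] + parts[stage + k * num_stages])
        for k in range(num_model_chunks)
    ]

assert stage_ranges(8, 2, 2, stage=0) == [(0, 2), (4, 6)]
assert stage_ranges(8, 2, 2, stage=1) == [(2, 4), (6, 8)]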

View File

@@ -57,9 +57,7 @@ def evaluate_model(
    def evaluate_subset(dataloader: DataLoader):
        use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
-            None if not booster.plugin.stage_manager.is_interleave else -1
-        )
+        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)

        accum_loss = torch.zeros(1, device=get_current_device())
        for batch in dataloader:

@@ -136,9 +134,7 @@ def train_epoch(
    coordinator: DistCoordinator,
):
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
-        None if not booster.plugin.stage_manager.is_interleave else -1
-    )
+    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)

    total_step = len(train_dataloader)

View File

@@ -133,7 +133,9 @@ def main():
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
+            pp_style="interleaved",
            zero_stage=args.zero,
+            num_model_chunks=2,
            enable_fused_normalization=torch.cuda.is_available(),
            num_microbatches=args.mbs,
            precision="bf16",

View File

@@ -1,47 +1,80 @@
+import warnings
+
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

+WORLD_SIZE = 2
+

def check_p2p_communication():
-    pg_mesh = ProcessGroupMesh(2)
+    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager)

    rank = dist.get_rank()

    tensor = torch.ones(1, device=get_current_device())
+    data = [
+        "tensor",
+        tensor,
+        [tensor],
+        {"tensor": tensor},
+    ]

    if rank == 0:
-        p2p.send_forward(tensor)
-        p2p.send_forward([tensor])
-        p2p.send_forward({"tensor": tensor})
-    else:
-        obj = p2p.recv_forward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_forward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_forward_recv_backward(data[i])
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 1:
+        for obj in data:
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            p2p.send_backward(data[-(i + 1)])
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == data[i]

    if rank == 1:
-        p2p.send_backward(tensor)
-        p2p.send_backward([tensor])
-        p2p.send_backward({"tensor": tensor})
-    else:
-        obj = p2p.recv_backward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_backward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_backward_recv_forward(data[i])
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 0:
+        for obj in data:
+            recv_obj = p2p.recv_backward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            recv_obj = p2p.recv_backward()
+            p2p.send_forward(data[-(i + 1)])
+            assert recv_obj == data[i]
+
+    warnings.filterwarnings("error")
+    tensor_metadata = TensorMetadata(
+        key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
+    )
+    comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
+    if rank == 0:
+        recv_obj = p2p.send_forward_recv_backward(
+            tensor,
+            send_metadata=False,
+            metadata_recv=comm_metadata,
+        )
+        assert recv_obj == tensor
+    elif rank == 1:
+        recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
+        assert recv_obj == tensor
+        p2p.send_backward(tensor, send_metadata=False)


def run_dist(rank, world_size, port):

@@ -52,7 +85,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
-    spawn(run_dist, 2)
+    spawn(run_dist, WORLD_SIZE)


if __name__ == "__main__":
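The last block of the test exercises the fast path: when both peers already agree on the payload's metadata, the sender can skip the metadata message and the receiver can decode with a precomputed description. The pattern in isolation, reusing only the classes imported above (values illustrative):

# Sketch of the precomputed-metadata fast path, on one rank pair.
t = torch.ones(1)
meta = P2PMetadata(
    data_type=P2PDataType.Tensor,
    content=TensorMetadata(key=None, shape=t.shape, dtype=t.dtype, requires_grad=t.requires_grad),
)
# sender:   p2p.send_forward(t, send_metadata=False)   # skip the metadata send
# receiver: p2p.recv_forward(metadata_recv=meta)       # decode via the shared description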

View File

@@ -37,9 +37,10 @@ def pp_linear_fwd(
    stage_mgr: PipelineStageManager = None,
    model_chunk_id: int = None,
):
-    if stage_mgr.is_first_stage(model_chunk_id):
-        return {"input_obj": forward(data)}
-    elif stage_mgr.is_last_stage(model_chunk_id):
-        return forward(input_obj)
-    else:
-        return {"input_obj": forward(input_obj)}
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        if stage_mgr.is_first_stage():
+            return {"input_obj": forward(data)}
+        elif stage_mgr.is_last_stage():
+            return forward(input_obj)
+        else:
+            return {"input_obj": forward(input_obj)}

@@ -107,7 +108,7 @@ def run_pp(
    )

    # check loss
-    if stage_manager.is_last_stage(-1):
+    if stage_manager.is_last_stage(ignore_chunk=True):
        assert torch.allclose(torch_loss, pp_ret["loss"])

    # check gradients

@@ -119,6 +120,7 @@ def run_pp(
    # step
    torch_optimizer.step()
    pp_optimizer.step()
+    pp_optimizer.zero_grad()

    # check updated param
    for i in range(num_model_chunk):

@@ -126,6 +128,24 @@ def run_pp(
        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)

+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+

@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])

View File

@@ -4,6 +4,7 @@ from types import MethodType

import pytest
import torch
+import torch.distributed as dist
import torch.nn as nn

import colossalai

@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all

+DIM = 8
+NUM_LAYER = 8
+

class MlpModel(nn.Module):
    def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        for layer in self.layers:
+            x = layer(x)
        return x


def pp_linear_fwd(
-    forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
+    forward,
+    data: torch.Tensor = None,
+    input_obj: torch.Tensor = None,
+    stage_mgr: PipelineStageManager = None,
):
    if stage_mgr.is_first_stage():
        return {"input_obj": forward(data)}

@@ -38,34 +44,45 @@ def pp_linear_fwd(
        return {"input_obj": forward(input_obj)}


-def examine_pp():
+def examine_pp(num_microbatch: int, batch_size: int):
    """
    This test is to examine the correctness of 1F1B, compared with torch.
    Be aware it contains some hardcodes.
    """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    dist.get_rank()
    seed_all(1453)

-    NUM_MICRO_BATCHS = 4
-    BATCH_SIZE = 4
-
    # create models
    torch_model = MlpModel().cuda()
    pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)

-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
-            sharded_model = sub_model.cuda()
+    rank = dist.get_rank()
+    sharded_model = torch.nn.ModuleList()
+    num_local_layer = NUM_LAYER // world_size
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx // num_local_layer == rank:
+            sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_local_layer

-    sharded_model._forward = sharded_model.forward
-    sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
+    def custom_fwd(self, x):
+        for layer in self._modules.values():
+            x = layer(x)
+        return x
+
+    sharded_model._forward = MethodType(custom_fwd, sharded_model)
+    sharded_model.forward = MethodType(
+        partial(
+            pp_linear_fwd,
+            stage_mgr=stage_manager,
+        ),
+        sharded_model._forward,
+    )

    # create optimizer
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

@@ -73,19 +90,15 @@ def examine_pp(num_microbatch: int, batch_size: int):
    # create
    seed_all(1453)
-    if stage_manager.is_first_stage():
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    criterion = lambda x, *arg, **kwargs: (x * x).mean()

    # forward and backward
    torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
    torch_loss.backward()

    pp_ret = schedule.forward_backward_step(
        sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
    )

@@ -95,34 +108,66 @@ def examine_pp(num_microbatch: int, batch_size: int):
        assert torch.allclose(torch_loss, pp_ret["loss"])

    # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
    pp_optimizer.step()
+    pp_optimizer.zero_grad()

    # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+
+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage():
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))


-def run_dist(rank, world_size, port):
+def run_dist(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    examine_pp(num_microbatch, batch_size)


@pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 2)
+def test_pp(num_microbatch: int, batch_size: int, world_size: int):
+    assert NUM_LAYER % world_size == 0
+    spawn(
+        run_dist,
+        world_size,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+    )


if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, world_size=4)
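The new forward-only block in both schedule tests encodes a simple invariant: after an optimizer step plus zero_grad, a no-grad forward pass must leave gradients untouched (None, or still zero). The same invariant on a toy layer, as a minimal sketch:

# Minimal sketch of the forward-only invariant on a single linear layer.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
opt = torch.optim.SGD(layer.parameters(), lr=1.0)

layer(torch.rand(4, 8)).mean().backward()
opt.step()
opt.zero_grad()  # grads become None (set_to_none=True) or zero

with torch.no_grad():  # mirrors the schedule's forward_only path
    layer(torch.rand(4, 8))

for p in layer.parameters():  # a forward-only pass must not touch grads
    assert p.grad is None or torch.all(p.grad == 0)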

View File

@@ -203,7 +203,7 @@ def check_output_hidden_state(
):
    org_hidden_state = org_output.last_hidden_state

-    if stage_manager and stage_manager.is_last_stage():
+    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

@@ -229,6 +229,10 @@ def check_weight(
        org_weight = getattr_(org_model, suffix).weight
        sharded_weight = getattr_(sharded_model, suffix).weight

+        # skip if layer is not held by this process
+        if sharded_weight is None:
+            continue
+
        if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
            sharded_weight_list = [
                torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))

View File

@@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
    col_layer_for_check = ["encoder.layer[0].output.dense"]
    row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
+    weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}

@@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        atol, rtol = 1e-4, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
-    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
        col_layer_grads = get_grad_tensors_for_check(
            bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )

@@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    sharded_optimizer.step()

    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
+    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:

@@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        atol, rtol = 5e-3, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
-    if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
+        check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)

    # check grads
    check_all_grad_tensors(grads_to_check)

@@ -183,6 +184,17 @@ def run_bert_test(test_config):
            "zero_stage": 1,
            "initial_scale": 1,
        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "interleaved",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
    ],
)
def run_bert_3d_test(test_config):

View File

@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
-    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-6, 1e-4
        else:

@@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    sharded_optimizer.step()

    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
+    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:

@@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
-    if stage_manager is None or stage_manager.is_first_stage():
+    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-4, 1e-3
        else:

@@ -179,6 +179,17 @@ def run_llama_test(test_config):
            "zero_stage": 1,
            "initial_scale": 1,
        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "interleaved",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
    ],
)
def run_llama_3d_test(test_config):