mirror of https://github.com/hpcaitech/ColossalAI
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests * fix: remove send_forward_recv_forward as p2p op list need to use the same group * fix: make send and receive atomic * feat: update P2PComm fn * feat: add metadata cache in 1f1b * feat: add metadata cache in interleaved pp * feat: modify is_xx_stage fn * revert: add _broadcast_object_list * feat: add interleaved pp in llama policy * feat: set NCCL_BUFFSIZE in HybridParallelPluginpull/5207/head
parent
af952673f7
commit
4fa689fca1
|
@ -1,4 +1,5 @@
|
|||
import ctypes
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
|
||||
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
|
||||
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
if os.getenv("NCCL_BUFFSIZE") is None:
|
||||
logger = get_dist_logger()
|
||||
logger.warning(
|
||||
"Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
|
||||
)
|
||||
os.environ["NCCL_BUFFSIZE"] = "134217728"
|
||||
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
import io
|
||||
import pickle
|
||||
import re
|
||||
from typing import Any, List, Optional, Union
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from packaging.version import Version
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed import distributed_c10d as c10d
|
||||
|
@ -20,7 +20,7 @@ from .stage_manager import PipelineStageManager
|
|||
_unpickler = pickle.Unpickler
|
||||
|
||||
|
||||
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
|
||||
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
|
||||
"""transform tensor to object with unpickle.
|
||||
Info of the device in bytes stream will be modified into current device before unpickling
|
||||
|
||||
|
@ -48,21 +48,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
|||
return unpickle
|
||||
|
||||
|
||||
def check_for_nccl_backend(group):
|
||||
pg = group or c10d._get_default_group()
|
||||
# Gate PG wrapper check on Gloo availability.
|
||||
if c10d._GLOO_AVAILABLE:
|
||||
# It is not expected for PG to be wrapped many times, but support it just
|
||||
# in case
|
||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||
pg = pg.wrapped_pg
|
||||
|
||||
return (
|
||||
c10d.is_nccl_available() and
|
||||
pg.name() == c10d.Backend.NCCL
|
||||
)
|
||||
|
||||
|
||||
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
|
||||
def _broadcast_object_list(
|
||||
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
|
||||
):
|
||||
|
@ -70,13 +56,11 @@ def _broadcast_object_list(
|
|||
The only difference is that object will be move to correct device after unpickled.
|
||||
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
|
||||
be updated with data sent from rank src.
|
||||
|
||||
Args:
|
||||
object_list (List[Any]): list of object to broadcast
|
||||
src (int): source rank to broadcast
|
||||
dst (int): dst rank to broadcast
|
||||
device (:class:`torch.device`): device to do broadcast. current device in default
|
||||
|
||||
"""
|
||||
|
||||
if c10d._rank_not_in_group(group):
|
||||
|
@ -131,7 +115,7 @@ def _broadcast_object_list(
|
|||
|
||||
if my_rank != src:
|
||||
for i, obj_size in enumerate(object_sizes_tensor):
|
||||
obj_view = object_tensor[offset: offset + obj_size]
|
||||
obj_view = object_tensor[offset : offset + obj_size]
|
||||
obj_view = obj_view.type(torch.uint8)
|
||||
if obj_view.device != torch.device("cpu"):
|
||||
obj_view = obj_view.cpu()
|
||||
|
@ -149,6 +133,18 @@ def _broadcast_object_list(
|
|||
object_list[i] = unpickle_object
|
||||
|
||||
|
||||
def check_for_nccl_backend(group):
|
||||
pg = group or c10d._get_default_group()
|
||||
# Gate PG wrapper check on Gloo availability.
|
||||
if c10d._GLOO_AVAILABLE:
|
||||
# It is not expected for PG to be wrapped many times, but support it just
|
||||
# in case
|
||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||
pg = pg.wrapped_pg
|
||||
|
||||
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
|
||||
|
||||
|
||||
def check_device(group):
|
||||
is_nccl_backend = check_for_nccl_backend(group)
|
||||
current_device = None
|
||||
|
@ -159,14 +155,14 @@ def check_device(group):
|
|||
return current_device, is_nccl_backend
|
||||
|
||||
|
||||
TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
|
||||
|
||||
|
||||
class P2PDataType(Enum):
|
||||
serialization = 0
|
||||
tensor = 1
|
||||
list = 2
|
||||
dict = 3
|
||||
Serialization = 0
|
||||
Tensor = 1
|
||||
List = 2
|
||||
Dict = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -175,45 +171,71 @@ class P2PMetadata:
|
|||
content: Union[List[TensorMetadata], TensorMetadata, Any]
|
||||
|
||||
|
||||
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
|
||||
def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
obj = obj.contiguous()
|
||||
op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
|
||||
ops_queue.append(op_to_add)
|
||||
else:
|
||||
for tensor_to_comm in obj:
|
||||
tensor_to_comm = tensor_to_comm.contiguous()
|
||||
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
|
||||
ops_queue.append(op_to_add)
|
||||
assert isinstance(tensor_to_comm, torch.Tensor)
|
||||
filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
|
||||
|
||||
|
||||
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
|
||||
if p2p_metadata.data_type == P2PDataType.tensor:
|
||||
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
|
||||
if p2p_metadata.data_type == P2PDataType.Tensor:
|
||||
metadata = p2p_metadata.content
|
||||
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
|
||||
tensor_recv = torch.empty(
|
||||
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
|
||||
)
|
||||
return tensor_recv
|
||||
elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
|
||||
elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
|
||||
buffer_recv = []
|
||||
for metadata in p2p_metadata.content:
|
||||
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
|
||||
tensor_recv = torch.empty(
|
||||
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
|
||||
)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv
|
||||
else:
|
||||
raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
|
||||
|
||||
|
||||
def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
|
||||
def create_fast_send_metadata(object: Any) -> P2PMetadata:
|
||||
assert _check_if_fast_send_available(object)
|
||||
if isinstance(object, torch.Tensor):
|
||||
data_type = P2PDataType.Tensor
|
||||
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
|
||||
elif isinstance(object, list):
|
||||
data_type = P2PDataType.List
|
||||
content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
|
||||
elif isinstance(object, dict):
|
||||
data_type = P2PDataType.Dict
|
||||
content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
|
||||
else:
|
||||
raise RuntimeError("Cannot handle object of type {}".format(type(object)))
|
||||
return P2PMetadata(data_type, content)
|
||||
|
||||
|
||||
def _batch_send_recv_tensor(
|
||||
send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||
recv_tensor_metadata: Optional[P2PMetadata],
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
send_group: Optional[ProcessGroup],
|
||||
recv_group: Optional[ProcessGroup],
|
||||
current_device: Any,
|
||||
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
buffer_recv = None
|
||||
if recv_tensor_metadata is not None:
|
||||
if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
|
||||
buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
|
||||
|
||||
ops = []
|
||||
|
||||
if send_dst is not None:
|
||||
if send_dst is not None and send_tensor_list is not None:
|
||||
assert send_group is not None
|
||||
filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if recv_src is not None:
|
||||
assert buffer_recv is not None
|
||||
if recv_src is not None and buffer_recv is not None:
|
||||
assert recv_group is not None
|
||||
filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
|
@ -221,24 +243,26 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
|
|||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Remove synchronization according to Pytorch's documentation
|
||||
# However, the Megatron-LM does synchronization here
|
||||
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
|
||||
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return buffer_recv
|
||||
|
||||
|
||||
def _send_recv_serialization_object(
|
||||
object: Any,
|
||||
send_dst: Optional[int], recv_src: Optional[int],
|
||||
send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
|
||||
current_device,
|
||||
is_nccl_backend):
|
||||
object: Any,
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
send_group: Optional[ProcessGroup],
|
||||
recv_group: Optional[ProcessGroup],
|
||||
current_device: Any,
|
||||
is_nccl_backend: bool,
|
||||
) -> Optional[P2PMetadata]:
|
||||
ops = []
|
||||
|
||||
send_object_tensor = None
|
||||
if object is not None and send_dst is not None:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
|
@ -264,10 +288,8 @@ def _send_recv_serialization_object(
|
|||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ops = []
|
||||
|
||||
|
@ -286,52 +308,77 @@ def _send_recv_serialization_object(
|
|||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if recv_object_tensor is not None and recv_object_size_tensor is not None:
|
||||
recv_object_tensor = recv_object_tensor.type(torch.uint8)
|
||||
if recv_object_tensor.device != torch.device("cpu"):
|
||||
recv_object_tensor = recv_object_tensor.cpu()
|
||||
|
||||
unpickle_object = _cuda_safe_tensor_to_object(
|
||||
recv_object_tensor, recv_object_size_tensor.item())
|
||||
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
|
||||
|
||||
if (
|
||||
isinstance(unpickle_object, torch.Tensor)
|
||||
and unpickle_object.device.index != torch.cuda.current_device()
|
||||
):
|
||||
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
|
||||
unpickle_object = unpickle_object.cuda()
|
||||
|
||||
return unpickle_object
|
||||
|
||||
|
||||
def _check_if_fast_send_available(object):
|
||||
if type(object) is torch.Tensor:
|
||||
def _check_if_fast_send_available(object: Any) -> bool:
|
||||
if isinstance(object, torch.Tensor):
|
||||
return True
|
||||
elif type(object) is list:
|
||||
is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
|
||||
elif isinstance(object, list):
|
||||
is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
|
||||
return is_list_of_tensor
|
||||
elif type(object) is dict:
|
||||
is_dict_of_tensor = all([type(k) is str and type(
|
||||
v) is torch.Tensor for k, v in object.items()])
|
||||
|
||||
elif isinstance(object, dict):
|
||||
is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
|
||||
return is_dict_of_tensor
|
||||
return False
|
||||
|
||||
|
||||
def _communicate(
|
||||
object,
|
||||
object: Any,
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
send_group: Optional[ProcessGroup] = None,
|
||||
recv_group: Optional[ProcessGroup] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Any:
|
||||
if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
|
||||
c10d._warn_not_in_group("_communicate")
|
||||
return
|
||||
"""
|
||||
Send and receive object from send_dst and recv_src respectively
|
||||
|
||||
Args:
|
||||
object (Any): object needed to be sent
|
||||
send_dst (int): rank of the destination
|
||||
recv_src (int): rank of the source
|
||||
send_group (ProcessGroup, optional): process group of sender
|
||||
recv_group (ProcessGroup, optional): process group of receiver
|
||||
send_metadata (bool, optional): whether to send metadata
|
||||
metadata_recv (P2PMetadata, optional): metadata of the object to be received
|
||||
"""
|
||||
assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
|
||||
assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
|
||||
assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
|
||||
send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
|
||||
assert (
|
||||
metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
|
||||
), "metadata_recv type must not be Serialization"
|
||||
|
||||
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
|
||||
# NOTE: only the following 5 cases are valid:
|
||||
# 1. send() [needs extra metadata] and no recv()
|
||||
# 2. recv() [needs extra metadata] and no send()
|
||||
# 3. neither send() nor recv() need extra metadata
|
||||
assert not (send_dst is not None and send_metadata) or recv_src is None
|
||||
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
|
||||
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
|
||||
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
|
||||
|
||||
current_send_device, is_send_nccl_backend = check_device(send_group)
|
||||
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
|
||||
|
@ -341,67 +388,56 @@ def _communicate(
|
|||
assert current_send_device == current_recv_device
|
||||
current_device = current_send_device
|
||||
|
||||
assert (send_dst is not None) or (recv_src is not None)
|
||||
|
||||
can_fast_send = False
|
||||
send_metadata = None
|
||||
if send_dst is not None:
|
||||
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
|
||||
if not can_fast_send:
|
||||
send_metadata = P2PMetadata(P2PDataType.serialization, object)
|
||||
else:
|
||||
if type(object) is torch.Tensor:
|
||||
data_type = P2PDataType.tensor
|
||||
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
|
||||
elif type(object) is list:
|
||||
data_type = P2PDataType.list
|
||||
content = []
|
||||
for v in object:
|
||||
content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
|
||||
elif type(object) is dict:
|
||||
data_type = P2PDataType.dict
|
||||
content = []
|
||||
for k, v in object.items():
|
||||
content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
|
||||
if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
|
||||
metadata_send = None
|
||||
if send_dst is not None and send_metadata:
|
||||
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
|
||||
if not can_fast_send:
|
||||
metadata_send = P2PMetadata(P2PDataType.Serialization, object)
|
||||
else:
|
||||
raise ValueError('Cannot send object of type {}'.format(type(object)))
|
||||
send_metadata = P2PMetadata(data_type, content)
|
||||
metadata_send = create_fast_send_metadata(object)
|
||||
|
||||
recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
|
||||
if recv_metadata is not None:
|
||||
assert type(recv_metadata) is P2PMetadata
|
||||
if recv_metadata.data_type == P2PDataType.serialization:
|
||||
return recv_metadata.content
|
||||
if not can_fast_send and send_dst is not None:
|
||||
return
|
||||
# Send and receive metadata
|
||||
_metadata_recv = _send_recv_serialization_object(
|
||||
object=metadata_send,
|
||||
send_dst=send_dst if send_metadata else None,
|
||||
recv_src=recv_src if metadata_recv is None else None,
|
||||
send_group=send_group if send_metadata else None,
|
||||
recv_group=recv_group if metadata_recv is None else None,
|
||||
current_device=current_device,
|
||||
is_nccl_backend=is_nccl_backend,
|
||||
)
|
||||
assert metadata_recv is None or _metadata_recv is None
|
||||
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
|
||||
|
||||
send_tensor_list = None
|
||||
if type(object) is torch.Tensor:
|
||||
if isinstance(object, torch.Tensor):
|
||||
send_tensor_list = object
|
||||
elif type(object) is list:
|
||||
elif isinstance(object, list):
|
||||
send_tensor_list = object
|
||||
elif type(object) is dict:
|
||||
elif isinstance(object, dict):
|
||||
send_tensor_list = list(object.values())
|
||||
|
||||
recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
|
||||
# Send and receive data
|
||||
recv_buffer = _batch_send_recv_tensor(
|
||||
send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
|
||||
)
|
||||
|
||||
if recv_metadata is not None:
|
||||
assert recv_buffer is not None
|
||||
if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
|
||||
return recv_buffer
|
||||
elif recv_metadata.data_type == P2PDataType.dict:
|
||||
return {
|
||||
k: v
|
||||
for k, v in zip(
|
||||
[m.key for m in recv_metadata.content],
|
||||
recv_buffer,
|
||||
)
|
||||
}
|
||||
if metadata_recv is not None:
|
||||
assert isinstance(metadata_recv, P2PMetadata)
|
||||
if metadata_recv.data_type == P2PDataType.Serialization:
|
||||
return metadata_recv.content
|
||||
else:
|
||||
raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
|
||||
assert recv_buffer is not None
|
||||
if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
|
||||
return recv_buffer
|
||||
elif metadata_recv.data_type == P2PDataType.Dict:
|
||||
return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
|
||||
else:
|
||||
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
|
||||
|
||||
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
|
||||
"""send anything to dst rank
|
||||
|
||||
Args:
|
||||
|
@ -411,10 +447,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group)
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
|
||||
|
||||
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
|
||||
"""recv anything from src
|
||||
|
||||
Args:
|
||||
|
@ -423,7 +459,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
|
|||
Returns:
|
||||
Any: Object received from src.
|
||||
"""
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
|
||||
|
||||
|
||||
def _p2p_comm(
|
||||
|
@ -436,7 +472,7 @@ def _p2p_comm(
|
|||
"""
|
||||
Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
|
||||
|
||||
Agrs:
|
||||
Args:
|
||||
tensor_send_next (torch.Tensor): tensor to be sent to next stage
|
||||
recv_prev (bool): whether to receive tensor from previous stage
|
||||
peer (int): rank of the peer
|
||||
|
@ -467,7 +503,6 @@ def _p2p_comm(
|
|||
group=group,
|
||||
)
|
||||
ops.append(recv_prev_op)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
|
@ -490,7 +525,6 @@ def _p2p_comm(
|
|||
group=group,
|
||||
)
|
||||
ops.append(send_next_op)
|
||||
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_op = dist.P2POp(
|
||||
dist.irecv,
|
||||
|
@ -510,7 +544,7 @@ class PipelineP2PCommunication:
|
|||
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
||||
self.stage_manager = stage_manager
|
||||
|
||||
def recv_forward(self, prev_rank: int = None) -> Any:
|
||||
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
|
||||
Args:
|
||||
|
@ -522,11 +556,13 @@ class PipelineP2PCommunication:
|
|||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
|
||||
input_tensor = _recv_object(
|
||||
prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, next_rank: int = None) -> Any:
|
||||
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
||||
Args:
|
||||
|
@ -539,12 +575,12 @@ class PipelineP2PCommunication:
|
|||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
output_tensor_grad = _recv_object(
|
||||
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
|
||||
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
|
||||
Args:
|
||||
|
@ -554,9 +590,15 @@ class PipelineP2PCommunication:
|
|||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
|
||||
_send_object(
|
||||
output_object,
|
||||
cur_rank,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
send_metadata,
|
||||
)
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
|
||||
Args:
|
||||
|
@ -566,9 +608,21 @@ class PipelineP2PCommunication:
|
|||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
|
||||
_send_object(
|
||||
input_object,
|
||||
cur_rank,
|
||||
prev_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
|
||||
send_metadata,
|
||||
)
|
||||
|
||||
def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
|
||||
def send_forward_recv_backward(
|
||||
self,
|
||||
input_object: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
||||
|
||||
Args:
|
||||
|
@ -581,11 +635,22 @@ class PipelineP2PCommunication:
|
|||
cur_rank = self.stage_manager.get_rank()
|
||||
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
||||
return _communicate(
|
||||
input_object, next_rank, next_rank,
|
||||
send_group=group, recv_group=group,
|
||||
input_object,
|
||||
next_rank,
|
||||
next_rank,
|
||||
send_group=group,
|
||||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
)
|
||||
|
||||
def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
|
||||
def send_backward_recv_forward(
|
||||
self,
|
||||
input_object: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
||||
|
||||
Args:
|
||||
|
@ -597,37 +662,22 @@ class PipelineP2PCommunication:
|
|||
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
||||
return _communicate(
|
||||
input_object, prev_rank, prev_rank,
|
||||
send_group=group, recv_group=group,
|
||||
)
|
||||
|
||||
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
|
||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the sender of the tensor
|
||||
next_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
||||
send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
||||
return _communicate(
|
||||
input_object,
|
||||
send_dst=next_rank,
|
||||
recv_src=prev_rank,
|
||||
send_group=send_group,
|
||||
recv_group=recv_group,
|
||||
prev_rank,
|
||||
prev_rank,
|
||||
send_group=group,
|
||||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
)
|
||||
|
||||
def p2p_communicate(
|
||||
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
|
||||
self,
|
||||
output_object: Any,
|
||||
recv_pre: bool,
|
||||
next_rank: Optional[int] = None,
|
||||
comm_dtype: torch.dtype = torch.float16,
|
||||
) -> None:
|
||||
"""
|
||||
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
|
||||
|
@ -636,10 +686,14 @@ class PipelineP2PCommunication:
|
|||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if peer is None:
|
||||
peer = self.stage_manager.get_next_rank()
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
recv_tensor = _p2p_comm(
|
||||
output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
|
||||
output_object,
|
||||
recv_pre,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
comm_dtype,
|
||||
)
|
||||
return recv_tensor
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
|
@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
assert (
|
||||
num_microbatch is not None or microbatch_size is not None
|
||||
), "Either num_microbatch or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatch = num_microbatch
|
||||
self.microbatch_size = microbatch_size
|
||||
|
@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
self.batch: Any
|
||||
self.batch_size: int
|
||||
self.last_batch_size: Optional[int] = None
|
||||
self.microbatch_offset: List[int]
|
||||
|
||||
# P2PMeta cache
|
||||
self.send_metadata_forward = True
|
||||
self.send_metadata_backward = True
|
||||
self.metadata_recv_forward = None
|
||||
self.metadata_recv_backward = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
batch = tree_map(partial(to_device, device=device), batch)
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
if self.last_batch_size is None:
|
||||
self.last_batch_size = self.batch_size
|
||||
else:
|
||||
assert self.forward_only or self.last_batch_size == self.batch_size
|
||||
# TODO: support arbitrary batch size when forward_only=True
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
if self.num_microbatch is not None:
|
||||
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
|
||||
|
@ -106,12 +119,13 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
"""
|
||||
if self.stage_manager.is_first_stage(model_chunk_id):
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor = self.comm.recv_forward(prev_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
|
||||
if self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
"""
|
||||
if self.stage_manager.is_last_stage(model_chunk_id):
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
|
||||
if self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage(model_chunk_id):
|
||||
self.comm.send_forward(output_object, next_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
|
||||
self.send_metadata_forward = False
|
||||
|
||||
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage(model_chunk_id):
|
||||
self.comm.send_backward(input_object, prev_rank)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
|
||||
self.send_metadata_backward = False
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_object,
|
||||
next_rank,
|
||||
send_metadata=self.send_metadata_forward,
|
||||
metadata_recv=self.metadata_recv_backward,
|
||||
)
|
||||
self.send_metadata_forward = False
|
||||
if self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
output_object,
|
||||
prev_rank,
|
||||
send_metadata=self.send_metadata_backward,
|
||||
metadata_recv=self.metadata_recv_forward,
|
||||
)
|
||||
self.send_metadata_backward = False
|
||||
if self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
|
@ -180,25 +233,24 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
# for the first stage, input_obj is None
|
||||
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
|
||||
|
||||
self.stage_manager.model_chunk_id = model_chunk_id
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
else:
|
||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||
self.stage_manager.model_chunk_id = None
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
else:
|
||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||
|
||||
if self.stage_manager.is_last_stage(model_chunk_id):
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
if self.stage_manager.is_last_stage():
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def backward_step(
|
||||
self,
|
||||
|
@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
|
||||
# TODO: handle arbitrary batch size when forward_only == True
|
||||
forward_only = not torch.is_grad_enabled()
|
||||
self.forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert forward_only, "Optimizer should be passed when doing backward."
|
||||
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
self.load_batch(data_iter)
|
||||
|
||||
num_microbatch = self.num_microbatch * self.num_model_chunks
|
||||
if forward_only:
|
||||
if self.forward_only:
|
||||
num_warmup_microbatch = num_microbatch
|
||||
else:
|
||||
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
|
||||
|
@ -288,43 +339,29 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
input_objs = None
|
||||
output_objs = None
|
||||
|
||||
if not forward_only:
|
||||
if not self.forward_only:
|
||||
input_objs = [[] for _ in range(self.num_model_chunks)]
|
||||
output_objs = [[] for _ in range(self.num_model_chunks)]
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
|
||||
|
||||
if return_loss and self.stage_manager.is_last_stage(-1):
|
||||
accum_loss = None
|
||||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
# for ranks except the first one, get into recv state
|
||||
input_obj = self.recv_forward(0)
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatch):
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
# recv first on first rank to avoid sending or receiving at the same time
|
||||
if self.stage_manager.is_first_stage(-1):
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
if not forward_only:
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
else:
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
if not forward_only:
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
if not self.forward_only:
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
|
||||
break
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
if num_microbatch_remaining > 0:
|
||||
model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatch_remaining):
|
||||
|
@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
last_iteration = i == num_microbatch_remaining - 1
|
||||
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
if forward_only:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if self.forward_only:
|
||||
if not last_iteration:
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
|
||||
else:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
else:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# backward
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
if last_iteration:
|
||||
input_obj = None
|
||||
else:
|
||||
if not last_iteration:
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
if not self.forward_only:
|
||||
for i in range(num_microbatch_remaining, num_microbatch):
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
input_obj = input_objs[model_chunk_id].pop(0)
|
||||
|
@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
if not forward_only:
|
||||
if not self.forward_only:
|
||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
if outputs is not None:
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
|
@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "Either num_microbatches or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatches = num_microbatches
|
||||
self.microbatch_size = microbatch_size
|
||||
self.batch: Optional[Any] = None
|
||||
self.batch_size: Optional[int] = None
|
||||
self.last_batch_size: Optional[int] = None
|
||||
self.microbatch_offset: Optional[int] = None
|
||||
self._use_microbatch_size = num_microbatches is None
|
||||
|
||||
# P2PMeta cache
|
||||
self.send_metadata_forward = True
|
||||
self.send_metadata_backward = True
|
||||
self.metadata_recv_forward = None
|
||||
self.metadata_recv_backward = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
batch = next(data_iter)
|
||||
if device is not None:
|
||||
batch = tree_map(partial(to_device, device=device), batch)
|
||||
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
if self.last_batch_size is None:
|
||||
self.last_batch_size = self.batch_size
|
||||
else:
|
||||
assert self.forward_only or self.last_batch_size == self.batch_size
|
||||
# TODO: support arbitrary batch size when forward_only=True
|
||||
self.microbatch_offset = 0
|
||||
if not self._use_microbatch_size:
|
||||
assert (
|
||||
|
@ -92,12 +106,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
"""
|
||||
if self.stage_manager.is_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor = self.comm.recv_forward(prev_rank)
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
|
||||
if self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, next_rank: int = None) -> Any:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
@ -109,12 +123,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
"""
|
||||
if self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank)
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
|
||||
if self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
|
@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_object, next_rank)
|
||||
|
||||
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
return self.comm.send_forward_recv_backward(output_object, next_rank)
|
||||
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
|
||||
self.send_metadata_forward = False
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
|
@ -147,7 +151,29 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_object, prev_rank)
|
||||
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
|
||||
self.send_metadata_backward = False
|
||||
|
||||
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_object,
|
||||
next_rank,
|
||||
send_metadata=self.send_metadata_forward,
|
||||
metadata_recv=self.metadata_recv_backward,
|
||||
)
|
||||
self.send_metadata_forward = False
|
||||
if self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
|
||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
||||
|
@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
return self.comm.send_backward_recv_forward(output_object, prev_rank)
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
output_object,
|
||||
prev_rank,
|
||||
send_metadata=self.send_metadata_backward,
|
||||
metadata_recv=self.metadata_recv_forward,
|
||||
)
|
||||
self.send_metadata_backward = False
|
||||
if self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
|
||||
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The previous rank of the recipient of the tensor.
|
||||
next_rank (int, optional): The next rank of the recipient of the tensor.
|
||||
"""
|
||||
if self.stage_manager.is_first_stage():
|
||||
return self.comm.send_forward(input_object, next_rank)
|
||||
elif self.stage_manager.is_last_stage():
|
||||
return self.comm.recv_forward(prev_rank)
|
||||
else:
|
||||
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
|
||||
return input_tensor
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
|
@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
|
||||
forward_only = not torch.is_grad_enabled()
|
||||
|
||||
self.forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert forward_only, "Optimizer should be passed when doing backward."
|
||||
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
self.load_batch(data_iter)
|
||||
|
||||
|
@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
input_objs = None
|
||||
output_objs = None
|
||||
|
||||
if not forward_only:
|
||||
if not self.forward_only:
|
||||
input_objs = []
|
||||
output_objs = []
|
||||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
accum_loss = None
|
||||
if return_loss and self.stage_manager.is_last_stage():
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = self.recv_forward()
|
||||
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
self.send_forward(output_obj)
|
||||
|
||||
if not forward_only:
|
||||
if not self.forward_only:
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
||||
|
@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
if forward_only:
|
||||
|
||||
if self.forward_only:
|
||||
self.send_forward(output_obj)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = self.recv_forward()
|
||||
else:
|
||||
# TODO adjust here
|
||||
self.send_forward(output_obj)
|
||||
output_obj_grad = self.recv_backward()
|
||||
|
||||
else:
|
||||
output_obj_grad = self.send_forward_recv_backward(output_obj)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
if last_iteration:
|
||||
input_obj = None
|
||||
self.send_backward(input_obj_grad)
|
||||
else:
|
||||
input_obj = self.recv_forward()
|
||||
self.send_backward(input_obj_grad)
|
||||
input_obj = self.send_backward_recv_forward(input_obj_grad)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
if not self.forward_only:
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
|
@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(input_obj_grad)
|
||||
|
||||
if not self.forward_only:
|
||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
if outputs is not None:
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch.distributed as dist
|
||||
|
@ -68,45 +69,39 @@ class PipelineStageManager:
|
|||
# for shardformer, hold model chunk id
|
||||
self.model_chunk_id: Optional[int] = None
|
||||
|
||||
def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
|
||||
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
|
||||
"""Is the current stage the first stage.
|
||||
|
||||
NOTE:
|
||||
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
|
||||
2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device()
|
||||
2. invoke is_first_stage() with ignore_chunk=True is equivalent to invoke is_first_device()
|
||||
|
||||
Returns:
|
||||
bool: Whether the current stage is the first stage.
|
||||
"""
|
||||
if self.is_interleave and model_chunk_id is None:
|
||||
model_chunk_id = self.model_chunk_id
|
||||
assert self.is_interleave ^ (
|
||||
model_chunk_id is None
|
||||
), "model_chunk_id must be specified when using interleaved pipeline"
|
||||
if not self.is_interleave or model_chunk_id < 0:
|
||||
assert isinstance(ignore_chunk, bool)
|
||||
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
|
||||
if not self.is_interleave or ignore_chunk:
|
||||
return self.stage == 0
|
||||
else:
|
||||
return self.stage == 0 and model_chunk_id == 0
|
||||
return self.stage == 0 and self.model_chunk_id == 0
|
||||
|
||||
def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
|
||||
def is_last_stage(self, ignore_chunk: bool = False) -> bool:
|
||||
"""Is the current stage the last stage.
|
||||
|
||||
NOTE:
|
||||
1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
|
||||
2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device()
|
||||
2. invoke is_last_stage() with ignore_chunk=True is equivalent to invoke is_last_device()
|
||||
|
||||
Returns:
|
||||
bool: Whether the current stage is the last stage.
|
||||
"""
|
||||
if self.is_interleave and model_chunk_id is None:
|
||||
model_chunk_id = self.model_chunk_id
|
||||
assert self.is_interleave ^ (
|
||||
model_chunk_id is None
|
||||
), "model_chunk_id must be specified when using interleaved pipeline"
|
||||
if not self.is_interleave or model_chunk_id < 0:
|
||||
assert isinstance(ignore_chunk, bool)
|
||||
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
|
||||
if not self.is_interleave or ignore_chunk:
|
||||
return self.stage == self.num_stages - 1
|
||||
else:
|
||||
return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1
|
||||
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
|
||||
|
||||
@property
|
||||
def num_stages(self) -> int:
|
||||
|
@ -174,3 +169,10 @@ class PipelineStageManager:
|
|||
ProcessGroup: Process group of the given stages.
|
||||
"""
|
||||
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def switch_model_chunk_id(self, model_chunk_id: int):
|
||||
old_model_chunk_id = self.model_chunk_id
|
||||
self.model_chunk_id = model_chunk_id
|
||||
yield
|
||||
self.model_chunk_id = old_model_chunk_id
|
||||
|
|
|
@ -309,11 +309,11 @@ class BertPolicy(Policy):
|
|||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
if stage_manager.is_first_stage(-1):
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embeddings)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(-1):
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(module.pooler)
|
||||
|
||||
else:
|
||||
|
@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
|
|||
"""Get pipeline layers for current stage"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.cls)
|
||||
|
||||
return held_layers
|
||||
|
@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.cls)
|
||||
return held_layers
|
||||
|
||||
|
@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.cls)
|
||||
return held_layers
|
||||
|
||||
|
@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.cls)
|
||||
return held_layers
|
||||
|
||||
|
@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
|
|
|
@ -8,7 +8,11 @@ from torch.nn import Module
|
|||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
|
||||
|
||||
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy
|
||||
from ..modeling.llama import (
|
||||
LlamaPipelineForwards,
|
||||
get_llama_flash_attention_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
|
||||
|
@ -140,21 +144,42 @@ class LlamaPolicy(Policy):
|
|||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "LlamaModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
if self.pipeline_stage_manager is None:
|
||||
return
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "LlamaModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_manager.stage_indices = Policy.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
method_replacement = {
|
||||
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)}
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
)
|
||||
}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
|
@ -167,13 +192,32 @@ class LlamaPolicy(Policy):
|
|||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_indices = Policy.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
new_item = {
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col
|
||||
)
|
||||
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
|
@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
|
|
|
@ -57,9 +57,7 @@ def evaluate_model(
|
|||
|
||||
def evaluate_subset(dataloader: DataLoader):
|
||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
|
||||
None if not booster.plugin.stage_manager.is_interleave else -1
|
||||
)
|
||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
|
||||
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
for batch in dataloader:
|
||||
|
@ -136,9 +134,7 @@ def train_epoch(
|
|||
coordinator: DistCoordinator,
|
||||
):
|
||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
|
||||
None if not booster.plugin.stage_manager.is_interleave else -1
|
||||
)
|
||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
|
||||
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
|
||||
total_step = len(train_dataloader)
|
||||
|
||||
|
|
|
@ -133,7 +133,9 @@ def main():
|
|||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
pp_style="interleaved",
|
||||
zero_stage=args.zero,
|
||||
num_model_chunks=2,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
num_microbatches=args.mbs,
|
||||
precision="bf16",
|
||||
|
|
|
@ -1,47 +1,80 @@
|
|||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
WORLD_SIZE = 2
|
||||
|
||||
|
||||
def check_p2p_communication():
|
||||
pg_mesh = ProcessGroupMesh(2)
|
||||
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, 0)
|
||||
p2p = PipelineP2PCommunication(stage_manager)
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
||||
tensor = torch.ones(1, device=get_current_device())
|
||||
data = [
|
||||
"tensor",
|
||||
tensor,
|
||||
[tensor],
|
||||
{"tensor": tensor},
|
||||
]
|
||||
|
||||
if rank == 0:
|
||||
p2p.send_forward(tensor)
|
||||
p2p.send_forward([tensor])
|
||||
p2p.send_forward({"tensor": tensor})
|
||||
else:
|
||||
obj = p2p.recv_forward()
|
||||
assert torch.equal(obj, tensor)
|
||||
obj = p2p.recv_forward()
|
||||
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
|
||||
obj = p2p.recv_forward()
|
||||
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
|
||||
for obj in data:
|
||||
p2p.send_forward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_forward_recv_backward(data[i])
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 1:
|
||||
for obj in data:
|
||||
recv_obj = p2p.recv_forward()
|
||||
assert recv_obj == obj
|
||||
for i in range(len(data)):
|
||||
p2p.send_backward(data[-(i + 1)])
|
||||
recv_obj = p2p.recv_forward()
|
||||
assert recv_obj == data[i]
|
||||
|
||||
if rank == 1:
|
||||
p2p.send_backward(tensor)
|
||||
p2p.send_backward([tensor])
|
||||
p2p.send_backward({"tensor": tensor})
|
||||
else:
|
||||
obj = p2p.recv_backward()
|
||||
assert torch.equal(obj, tensor)
|
||||
obj = p2p.recv_backward()
|
||||
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
|
||||
obj = p2p.recv_backward()
|
||||
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
|
||||
for obj in data:
|
||||
p2p.send_backward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_backward_recv_forward(data[i])
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 0:
|
||||
for obj in data:
|
||||
recv_obj = p2p.recv_backward()
|
||||
assert recv_obj == obj
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.recv_backward()
|
||||
p2p.send_forward(data[-(i + 1)])
|
||||
assert recv_obj == data[i]
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
tensor_metadata = TensorMetadata(
|
||||
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
|
||||
)
|
||||
comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
|
||||
if rank == 0:
|
||||
recv_obj = p2p.send_forward_recv_backward(
|
||||
tensor,
|
||||
send_metadata=False,
|
||||
metadata_recv=comm_metadata,
|
||||
)
|
||||
assert recv_obj == tensor
|
||||
elif rank == 1:
|
||||
recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
|
||||
assert recv_obj == tensor
|
||||
p2p.send_backward(tensor, send_metadata=False)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
@ -52,7 +85,7 @@ def run_dist(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pipeline_p2p():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, WORLD_SIZE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -37,12 +37,13 @@ def pp_linear_fwd(
|
|||
stage_mgr: PipelineStageManager = None,
|
||||
model_chunk_id: int = None,
|
||||
):
|
||||
if stage_mgr.is_first_stage(model_chunk_id):
|
||||
return {"input_obj": forward(data)}
|
||||
elif stage_mgr.is_last_stage(model_chunk_id):
|
||||
return forward(input_obj)
|
||||
else:
|
||||
return {"input_obj": forward(input_obj)}
|
||||
with stage_mgr.switch_model_chunk_id(model_chunk_id):
|
||||
if stage_mgr.is_first_stage():
|
||||
return {"input_obj": forward(data)}
|
||||
elif stage_mgr.is_last_stage():
|
||||
return forward(input_obj)
|
||||
else:
|
||||
return {"input_obj": forward(input_obj)}
|
||||
|
||||
|
||||
def run_pp(
|
||||
|
@ -107,7 +108,7 @@ def run_pp(
|
|||
)
|
||||
|
||||
# check loss
|
||||
if stage_manager.is_last_stage(-1):
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
||||
# check gradients
|
||||
|
@ -119,6 +120,7 @@ def run_pp(
|
|||
# step
|
||||
torch_optimizer.step()
|
||||
pp_optimizer.step()
|
||||
pp_optimizer.zero_grad()
|
||||
|
||||
# check updated param
|
||||
for i in range(num_model_chunk):
|
||||
|
@ -126,6 +128,24 @@ def run_pp(
|
|||
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
|
||||
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
|
||||
|
||||
# forward only
|
||||
with torch.no_grad():
|
||||
torch_output = torch_model(input_list[0])
|
||||
torch_loss = criterion(torch_output)
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
)
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
||||
for layer in sharded_model:
|
||||
if layer.weight.grad is None:
|
||||
assert layer.weight.grad is None and layer.bias.grad is None
|
||||
else:
|
||||
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
|
||||
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_microbatch", [4, 12])
|
||||
|
|
|
@ -4,6 +4,7 @@ from types import MethodType
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
|
@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
|
||||
DIM = 8
|
||||
NUM_LAYER = 8
|
||||
|
||||
|
||||
class MlpModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MlpModel, self).__init__()
|
||||
self.linear1 = nn.Linear(4, 8)
|
||||
self.linear2 = nn.Linear(8, 4)
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def pp_linear_fwd(
|
||||
forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
|
||||
forward,
|
||||
data: torch.Tensor = None,
|
||||
input_obj: torch.Tensor = None,
|
||||
stage_mgr: PipelineStageManager = None,
|
||||
):
|
||||
if stage_mgr.is_first_stage():
|
||||
return {"input_obj": forward(data)}
|
||||
|
@ -38,34 +44,45 @@ def pp_linear_fwd(
|
|||
return {"input_obj": forward(input_obj)}
|
||||
|
||||
|
||||
def examine_pp():
|
||||
def examine_pp(num_microbatch: int, batch_size: int):
|
||||
"""
|
||||
This test is to examine the correctness of 1F1B, compared with torch.
|
||||
Be aware it contains some hardcodes.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size()
|
||||
local_rank = torch.distributed.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
dist.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
NUM_MICRO_BATCHS = 4
|
||||
BATCH_SIZE = 4
|
||||
|
||||
# create models
|
||||
torch_model = MlpModel().cuda()
|
||||
|
||||
pp_model = copy.deepcopy(torch_model).cuda()
|
||||
|
||||
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
|
||||
pg_mesh = ProcessGroupMesh(1, world_size, 1)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
|
||||
pg_mesh = ProcessGroupMesh(world_size)
|
||||
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
|
||||
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)
|
||||
|
||||
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
|
||||
if idx % (world_size) == local_rank:
|
||||
sharded_model = sub_model.cuda()
|
||||
rank = dist.get_rank()
|
||||
sharded_model = torch.nn.ModuleList()
|
||||
num_local_layer = NUM_LAYER // world_size
|
||||
for idx, sub_model in enumerate(pp_model.layers):
|
||||
if idx // num_local_layer == rank:
|
||||
sharded_model.append(sub_model.cuda())
|
||||
assert len(sharded_model) == num_local_layer
|
||||
|
||||
sharded_model._forward = sharded_model.forward
|
||||
sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
|
||||
def custom_fwd(self, x):
|
||||
for layer in self._modules.values():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
sharded_model._forward = MethodType(custom_fwd, sharded_model)
|
||||
sharded_model.forward = MethodType(
|
||||
partial(
|
||||
pp_linear_fwd,
|
||||
stage_mgr=stage_manager,
|
||||
),
|
||||
sharded_model._forward,
|
||||
)
|
||||
|
||||
# create optimizer
|
||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||
|
@ -73,19 +90,15 @@ def examine_pp():
|
|||
|
||||
# create
|
||||
seed_all(1453)
|
||||
if stage_manager.is_first_stage():
|
||||
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
|
||||
else:
|
||||
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
|
||||
torch.distributed.all_reduce(input_list[0])
|
||||
input_list = [torch.rand(batch_size, DIM).cuda()]
|
||||
dist.all_reduce(input_list[0])
|
||||
|
||||
criterion = lambda x, y: torch.mean(x)
|
||||
criterion = lambda x, *arg, **kwargs: (x * x).mean()
|
||||
|
||||
# forward and backward
|
||||
torch_output = torch_model(input_list[0])
|
||||
torch_loss = criterion(torch_output, _)
|
||||
torch_loss = criterion(torch_output)
|
||||
torch_loss.backward()
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
)
|
||||
|
@ -95,34 +108,66 @@ def examine_pp():
|
|||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
||||
# check gradients
|
||||
torch_grad = []
|
||||
for torch_p in torch_model.parameters():
|
||||
torch_grad.append(torch_p.grad.data)
|
||||
for idx, pp_p in enumerate(sharded_model.parameters()):
|
||||
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
|
||||
for i in range(len(sharded_model)):
|
||||
idx = rank * num_local_layer + i
|
||||
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
|
||||
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
|
||||
|
||||
# step
|
||||
torch_optimizer.step()
|
||||
pp_optimizer.step()
|
||||
pp_optimizer.zero_grad()
|
||||
|
||||
# check updated param
|
||||
torch_param = []
|
||||
for torch_p in torch_model.parameters():
|
||||
torch_param.append(torch_p.data)
|
||||
for idx, pp_p in enumerate(sharded_model.parameters()):
|
||||
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
|
||||
for i in range(len(sharded_model)):
|
||||
idx = rank * num_local_layer + i
|
||||
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
|
||||
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
|
||||
|
||||
# forward only
|
||||
with torch.no_grad():
|
||||
torch_output = torch_model(input_list[0])
|
||||
torch_loss = criterion(torch_output)
|
||||
|
||||
pp_ret = schedule.forward_backward_step(
|
||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||
|
||||
for layer in sharded_model:
|
||||
if layer.weight.grad is None:
|
||||
assert layer.weight.grad is None and layer.bias.grad is None
|
||||
else:
|
||||
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
|
||||
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
def run_dist(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
num_microbatch: int,
|
||||
batch_size: int,
|
||||
):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
examine_pp()
|
||||
examine_pp(num_microbatch, batch_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_microbatch", [4, 12])
|
||||
@pytest.mark.parametrize("batch_size", [12])
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pp():
|
||||
spawn(run_dist, 2)
|
||||
def test_pp(num_microbatch: int, batch_size: int, world_size: int):
|
||||
assert NUM_LAYER % world_size == 0
|
||||
spawn(
|
||||
run_dist,
|
||||
world_size,
|
||||
num_microbatch=num_microbatch,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pp()
|
||||
test_pp(num_microbatch=4, batch_size=4, world_size=4)
|
||||
|
|
|
@ -203,7 +203,7 @@ def check_output_hidden_state(
|
|||
):
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
if stage_manager and stage_manager.is_last_stage():
|
||||
if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
|
||||
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
|
||||
else:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
@ -229,6 +229,10 @@ def check_weight(
|
|||
org_weight = getattr_(org_model, suffix).weight
|
||||
sharded_weight = getattr_(sharded_model, suffix).weight
|
||||
|
||||
# skip if layer is not held by this process
|
||||
if sharded_weight is None:
|
||||
continue
|
||||
|
||||
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
|
||||
sharded_weight_list = [
|
||||
torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
|
||||
|
|
|
@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
|
||||
col_layer_for_check = ["encoder.layer[0].output.dense"]
|
||||
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
|
||||
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
|
@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
|
||||
col_layer_grads = get_grad_tensors_for_check(
|
||||
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
)
|
||||
|
@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
|
@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
||||
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
@ -183,6 +184,17 @@ def run_bert_test(test_config):
|
|||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"pp_style": "interleaved",
|
||||
"num_model_chunks": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_bert_3d_test(test_config):
|
||||
|
|
|
@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-6, 1e-4
|
||||
else:
|
||||
|
@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
|
@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
|
@ -179,6 +179,17 @@ def run_llama_test(test_config):
|
|||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"pp_style": "interleaved",
|
||||
"num_model_chunks": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_llama_3d_test(test_config):
|
||||
|
|
Loading…
Reference in New Issue