mirror of https://github.com/hpcaitech/ColossalAI
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests * fix: remove send_forward_recv_forward as p2p op list need to use the same group * fix: make send and receive atomic * feat: update P2PComm fn * feat: add metadata cache in 1f1b * feat: add metadata cache in interleaved pp * feat: modify is_xx_stage fn * revert: add _broadcast_object_list * feat: add interleaved pp in llama policy * feat: set NCCL_BUFFSIZE in HybridParallelPluginpull/5207/head
parent
af952673f7
commit
4fa689fca1
|
@ -1,4 +1,5 @@
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler
|
||||||
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
|
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
|
||||||
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
|
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||||
|
@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
self.custom_policy = custom_policy
|
self.custom_policy = custom_policy
|
||||||
assert zero_stage in (0, 1, 2)
|
assert zero_stage in (0, 1, 2)
|
||||||
if self.pp_size > 1:
|
if self.pp_size > 1:
|
||||||
|
if os.getenv("NCCL_BUFFSIZE") is None:
|
||||||
|
logger = get_dist_logger()
|
||||||
|
logger.warning(
|
||||||
|
"Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
|
||||||
|
)
|
||||||
|
os.environ["NCCL_BUFFSIZE"] = "134217728"
|
||||||
|
|
||||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||||
assert (
|
assert (
|
||||||
|
|
|
@ -4,13 +4,13 @@
|
||||||
import io
|
import io
|
||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
from typing import Any, List, Optional, Union
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.distributed import distributed_c10d as c10d
|
from torch.distributed import distributed_c10d as c10d
|
||||||
|
@ -20,7 +20,7 @@ from .stage_manager import PipelineStageManager
|
||||||
_unpickler = pickle.Unpickler
|
_unpickler = pickle.Unpickler
|
||||||
|
|
||||||
|
|
||||||
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
|
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
|
||||||
"""transform tensor to object with unpickle.
|
"""transform tensor to object with unpickle.
|
||||||
Info of the device in bytes stream will be modified into current device before unpickling
|
Info of the device in bytes stream will be modified into current device before unpickling
|
||||||
|
|
||||||
|
@ -48,21 +48,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
||||||
return unpickle
|
return unpickle
|
||||||
|
|
||||||
|
|
||||||
def check_for_nccl_backend(group):
|
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
|
||||||
pg = group or c10d._get_default_group()
|
|
||||||
# Gate PG wrapper check on Gloo availability.
|
|
||||||
if c10d._GLOO_AVAILABLE:
|
|
||||||
# It is not expected for PG to be wrapped many times, but support it just
|
|
||||||
# in case
|
|
||||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
|
||||||
pg = pg.wrapped_pg
|
|
||||||
|
|
||||||
return (
|
|
||||||
c10d.is_nccl_available() and
|
|
||||||
pg.name() == c10d.Backend.NCCL
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _broadcast_object_list(
|
def _broadcast_object_list(
|
||||||
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
|
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
|
||||||
):
|
):
|
||||||
|
@ -70,13 +56,11 @@ def _broadcast_object_list(
|
||||||
The only difference is that object will be move to correct device after unpickled.
|
The only difference is that object will be move to correct device after unpickled.
|
||||||
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
|
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
|
||||||
be updated with data sent from rank src.
|
be updated with data sent from rank src.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
object_list (List[Any]): list of object to broadcast
|
object_list (List[Any]): list of object to broadcast
|
||||||
src (int): source rank to broadcast
|
src (int): source rank to broadcast
|
||||||
dst (int): dst rank to broadcast
|
dst (int): dst rank to broadcast
|
||||||
device (:class:`torch.device`): device to do broadcast. current device in default
|
device (:class:`torch.device`): device to do broadcast. current device in default
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if c10d._rank_not_in_group(group):
|
if c10d._rank_not_in_group(group):
|
||||||
|
@ -131,7 +115,7 @@ def _broadcast_object_list(
|
||||||
|
|
||||||
if my_rank != src:
|
if my_rank != src:
|
||||||
for i, obj_size in enumerate(object_sizes_tensor):
|
for i, obj_size in enumerate(object_sizes_tensor):
|
||||||
obj_view = object_tensor[offset: offset + obj_size]
|
obj_view = object_tensor[offset : offset + obj_size]
|
||||||
obj_view = obj_view.type(torch.uint8)
|
obj_view = obj_view.type(torch.uint8)
|
||||||
if obj_view.device != torch.device("cpu"):
|
if obj_view.device != torch.device("cpu"):
|
||||||
obj_view = obj_view.cpu()
|
obj_view = obj_view.cpu()
|
||||||
|
@ -149,6 +133,18 @@ def _broadcast_object_list(
|
||||||
object_list[i] = unpickle_object
|
object_list[i] = unpickle_object
|
||||||
|
|
||||||
|
|
||||||
|
def check_for_nccl_backend(group):
|
||||||
|
pg = group or c10d._get_default_group()
|
||||||
|
# Gate PG wrapper check on Gloo availability.
|
||||||
|
if c10d._GLOO_AVAILABLE:
|
||||||
|
# It is not expected for PG to be wrapped many times, but support it just
|
||||||
|
# in case
|
||||||
|
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||||
|
pg = pg.wrapped_pg
|
||||||
|
|
||||||
|
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
|
||||||
|
|
||||||
|
|
||||||
def check_device(group):
|
def check_device(group):
|
||||||
is_nccl_backend = check_for_nccl_backend(group)
|
is_nccl_backend = check_for_nccl_backend(group)
|
||||||
current_device = None
|
current_device = None
|
||||||
|
@ -159,14 +155,14 @@ def check_device(group):
|
||||||
return current_device, is_nccl_backend
|
return current_device, is_nccl_backend
|
||||||
|
|
||||||
|
|
||||||
TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
|
TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
|
||||||
|
|
||||||
|
|
||||||
class P2PDataType(Enum):
|
class P2PDataType(Enum):
|
||||||
serialization = 0
|
Serialization = 0
|
||||||
tensor = 1
|
Tensor = 1
|
||||||
list = 2
|
List = 2
|
||||||
dict = 3
|
Dict = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -175,45 +171,71 @@ class P2PMetadata:
|
||||||
content: Union[List[TensorMetadata], TensorMetadata, Any]
|
content: Union[List[TensorMetadata], TensorMetadata, Any]
|
||||||
|
|
||||||
|
|
||||||
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
|
def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
obj = obj.contiguous()
|
obj = obj.contiguous()
|
||||||
op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
|
op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
|
||||||
ops_queue.append(op_to_add)
|
ops_queue.append(op_to_add)
|
||||||
else:
|
else:
|
||||||
for tensor_to_comm in obj:
|
for tensor_to_comm in obj:
|
||||||
tensor_to_comm = tensor_to_comm.contiguous()
|
assert isinstance(tensor_to_comm, torch.Tensor)
|
||||||
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
|
filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
|
||||||
ops_queue.append(op_to_add)
|
|
||||||
|
|
||||||
|
|
||||||
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
|
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
|
||||||
if p2p_metadata.data_type == P2PDataType.tensor:
|
if p2p_metadata.data_type == P2PDataType.Tensor:
|
||||||
metadata = p2p_metadata.content
|
metadata = p2p_metadata.content
|
||||||
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
|
tensor_recv = torch.empty(
|
||||||
|
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
|
||||||
|
)
|
||||||
return tensor_recv
|
return tensor_recv
|
||||||
elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
|
elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
|
||||||
buffer_recv = []
|
buffer_recv = []
|
||||||
for metadata in p2p_metadata.content:
|
for metadata in p2p_metadata.content:
|
||||||
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
|
tensor_recv = torch.empty(
|
||||||
|
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
|
||||||
|
)
|
||||||
buffer_recv.append(tensor_recv)
|
buffer_recv.append(tensor_recv)
|
||||||
return buffer_recv
|
return buffer_recv
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
|
raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
|
||||||
|
|
||||||
|
|
||||||
def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
|
def create_fast_send_metadata(object: Any) -> P2PMetadata:
|
||||||
|
assert _check_if_fast_send_available(object)
|
||||||
|
if isinstance(object, torch.Tensor):
|
||||||
|
data_type = P2PDataType.Tensor
|
||||||
|
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
|
||||||
|
elif isinstance(object, list):
|
||||||
|
data_type = P2PDataType.List
|
||||||
|
content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
|
||||||
|
elif isinstance(object, dict):
|
||||||
|
data_type = P2PDataType.Dict
|
||||||
|
content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Cannot handle object of type {}".format(type(object)))
|
||||||
|
return P2PMetadata(data_type, content)
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_send_recv_tensor(
|
||||||
|
send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||||
|
recv_tensor_metadata: Optional[P2PMetadata],
|
||||||
|
send_dst: Optional[int],
|
||||||
|
recv_src: Optional[int],
|
||||||
|
send_group: Optional[ProcessGroup],
|
||||||
|
recv_group: Optional[ProcessGroup],
|
||||||
|
current_device: Any,
|
||||||
|
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||||
buffer_recv = None
|
buffer_recv = None
|
||||||
if recv_tensor_metadata is not None:
|
if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
|
||||||
buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
|
buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
|
||||||
|
|
||||||
ops = []
|
ops = []
|
||||||
|
if send_dst is not None and send_tensor_list is not None:
|
||||||
if send_dst is not None:
|
assert send_group is not None
|
||||||
filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||||
|
if recv_src is not None and buffer_recv is not None:
|
||||||
if recv_src is not None:
|
assert recv_group is not None
|
||||||
assert buffer_recv is not None
|
|
||||||
filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||||
|
|
||||||
if len(ops) > 0:
|
if len(ops) > 0:
|
||||||
|
@ -221,24 +243,26 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
req.wait()
|
req.wait()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Remove synchronization according to Pytorch's documentation
|
# Remove synchronization according to Pytorch's documentation
|
||||||
# However, the Megatron-LM does synchronization here
|
# However, the Megatron-LM does synchronization here
|
||||||
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
|
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
|
||||||
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
|
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return buffer_recv
|
return buffer_recv
|
||||||
|
|
||||||
|
|
||||||
def _send_recv_serialization_object(
|
def _send_recv_serialization_object(
|
||||||
object: Any,
|
object: Any,
|
||||||
send_dst: Optional[int], recv_src: Optional[int],
|
send_dst: Optional[int],
|
||||||
send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
|
recv_src: Optional[int],
|
||||||
current_device,
|
send_group: Optional[ProcessGroup],
|
||||||
is_nccl_backend):
|
recv_group: Optional[ProcessGroup],
|
||||||
|
current_device: Any,
|
||||||
|
is_nccl_backend: bool,
|
||||||
|
) -> Optional[P2PMetadata]:
|
||||||
ops = []
|
ops = []
|
||||||
|
|
||||||
send_object_tensor = None
|
send_object_tensor = None
|
||||||
if object is not None and send_dst is not None:
|
if object is not None and send_dst is not None:
|
||||||
if Version(torch.__version__) >= Version("1.13.0"):
|
if Version(torch.__version__) >= Version("1.13.0"):
|
||||||
|
@ -264,10 +288,8 @@ def _send_recv_serialization_object(
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
req.wait()
|
req.wait()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# See the comment in `_batch_send_recv_tensor`
|
# See the comment in `_batch_send_recv_tensor`
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
ops = []
|
ops = []
|
||||||
|
|
||||||
|
@ -286,52 +308,77 @@ def _send_recv_serialization_object(
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
req.wait()
|
req.wait()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# See the comment in `_batch_send_recv_tensor`
|
# See the comment in `_batch_send_recv_tensor`
|
||||||
# torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
if recv_object_tensor is not None and recv_object_size_tensor is not None:
|
if recv_object_tensor is not None and recv_object_size_tensor is not None:
|
||||||
recv_object_tensor = recv_object_tensor.type(torch.uint8)
|
recv_object_tensor = recv_object_tensor.type(torch.uint8)
|
||||||
if recv_object_tensor.device != torch.device("cpu"):
|
if recv_object_tensor.device != torch.device("cpu"):
|
||||||
recv_object_tensor = recv_object_tensor.cpu()
|
recv_object_tensor = recv_object_tensor.cpu()
|
||||||
|
|
||||||
unpickle_object = _cuda_safe_tensor_to_object(
|
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
|
||||||
recv_object_tensor, recv_object_size_tensor.item())
|
|
||||||
|
|
||||||
if (
|
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
|
||||||
isinstance(unpickle_object, torch.Tensor)
|
|
||||||
and unpickle_object.device.index != torch.cuda.current_device()
|
|
||||||
):
|
|
||||||
unpickle_object = unpickle_object.cuda()
|
unpickle_object = unpickle_object.cuda()
|
||||||
|
|
||||||
return unpickle_object
|
return unpickle_object
|
||||||
|
|
||||||
|
|
||||||
def _check_if_fast_send_available(object):
|
def _check_if_fast_send_available(object: Any) -> bool:
|
||||||
if type(object) is torch.Tensor:
|
if isinstance(object, torch.Tensor):
|
||||||
return True
|
return True
|
||||||
elif type(object) is list:
|
elif isinstance(object, list):
|
||||||
is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
|
is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
|
||||||
return is_list_of_tensor
|
return is_list_of_tensor
|
||||||
elif type(object) is dict:
|
elif isinstance(object, dict):
|
||||||
is_dict_of_tensor = all([type(k) is str and type(
|
is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
|
||||||
v) is torch.Tensor for k, v in object.items()])
|
|
||||||
|
|
||||||
return is_dict_of_tensor
|
return is_dict_of_tensor
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _communicate(
|
def _communicate(
|
||||||
object,
|
object: Any,
|
||||||
send_dst: Optional[int],
|
send_dst: Optional[int],
|
||||||
recv_src: Optional[int],
|
recv_src: Optional[int],
|
||||||
send_group: Optional[ProcessGroup] = None,
|
send_group: Optional[ProcessGroup] = None,
|
||||||
recv_group: Optional[ProcessGroup] = None,
|
recv_group: Optional[ProcessGroup] = None,
|
||||||
|
send_metadata: bool = True,
|
||||||
|
metadata_recv: Optional[P2PMetadata] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
|
"""
|
||||||
c10d._warn_not_in_group("_communicate")
|
Send and receive object from send_dst and recv_src respectively
|
||||||
return
|
|
||||||
|
Args:
|
||||||
|
object (Any): object needed to be sent
|
||||||
|
send_dst (int): rank of the destination
|
||||||
|
recv_src (int): rank of the source
|
||||||
|
send_group (ProcessGroup, optional): process group of sender
|
||||||
|
recv_group (ProcessGroup, optional): process group of receiver
|
||||||
|
send_metadata (bool, optional): whether to send metadata
|
||||||
|
metadata_recv (P2PMetadata, optional): metadata of the object to be received
|
||||||
|
"""
|
||||||
|
assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
|
||||||
|
assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
|
||||||
|
assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
|
||||||
|
send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
|
||||||
|
assert (
|
||||||
|
metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
|
||||||
|
), "metadata_recv type must not be Serialization"
|
||||||
|
|
||||||
|
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||||
|
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||||
|
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
|
||||||
|
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||||
|
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||||
|
|
||||||
|
# NOTE: only the following 5 cases are valid:
|
||||||
|
# 1. send() [needs extra metadata] and no recv()
|
||||||
|
# 2. recv() [needs extra metadata] and no send()
|
||||||
|
# 3. neither send() nor recv() need extra metadata
|
||||||
|
assert not (send_dst is not None and send_metadata) or recv_src is None
|
||||||
|
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
|
||||||
|
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
|
||||||
|
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
|
||||||
|
|
||||||
current_send_device, is_send_nccl_backend = check_device(send_group)
|
current_send_device, is_send_nccl_backend = check_device(send_group)
|
||||||
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
|
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
|
||||||
|
@ -341,67 +388,56 @@ def _communicate(
|
||||||
assert current_send_device == current_recv_device
|
assert current_send_device == current_recv_device
|
||||||
current_device = current_send_device
|
current_device = current_send_device
|
||||||
|
|
||||||
assert (send_dst is not None) or (recv_src is not None)
|
if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
|
||||||
|
metadata_send = None
|
||||||
can_fast_send = False
|
if send_dst is not None and send_metadata:
|
||||||
send_metadata = None
|
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
|
||||||
if send_dst is not None:
|
if not can_fast_send:
|
||||||
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
|
metadata_send = P2PMetadata(P2PDataType.Serialization, object)
|
||||||
if not can_fast_send:
|
|
||||||
send_metadata = P2PMetadata(P2PDataType.serialization, object)
|
|
||||||
else:
|
|
||||||
if type(object) is torch.Tensor:
|
|
||||||
data_type = P2PDataType.tensor
|
|
||||||
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
|
|
||||||
elif type(object) is list:
|
|
||||||
data_type = P2PDataType.list
|
|
||||||
content = []
|
|
||||||
for v in object:
|
|
||||||
content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
|
|
||||||
elif type(object) is dict:
|
|
||||||
data_type = P2PDataType.dict
|
|
||||||
content = []
|
|
||||||
for k, v in object.items():
|
|
||||||
content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('Cannot send object of type {}'.format(type(object)))
|
metadata_send = create_fast_send_metadata(object)
|
||||||
send_metadata = P2PMetadata(data_type, content)
|
|
||||||
|
|
||||||
recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
|
# Send and receive metadata
|
||||||
if recv_metadata is not None:
|
_metadata_recv = _send_recv_serialization_object(
|
||||||
assert type(recv_metadata) is P2PMetadata
|
object=metadata_send,
|
||||||
if recv_metadata.data_type == P2PDataType.serialization:
|
send_dst=send_dst if send_metadata else None,
|
||||||
return recv_metadata.content
|
recv_src=recv_src if metadata_recv is None else None,
|
||||||
if not can_fast_send and send_dst is not None:
|
send_group=send_group if send_metadata else None,
|
||||||
return
|
recv_group=recv_group if metadata_recv is None else None,
|
||||||
|
current_device=current_device,
|
||||||
|
is_nccl_backend=is_nccl_backend,
|
||||||
|
)
|
||||||
|
assert metadata_recv is None or _metadata_recv is None
|
||||||
|
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
|
||||||
|
|
||||||
send_tensor_list = None
|
send_tensor_list = None
|
||||||
if type(object) is torch.Tensor:
|
if isinstance(object, torch.Tensor):
|
||||||
send_tensor_list = object
|
send_tensor_list = object
|
||||||
elif type(object) is list:
|
elif isinstance(object, list):
|
||||||
send_tensor_list = object
|
send_tensor_list = object
|
||||||
elif type(object) is dict:
|
elif isinstance(object, dict):
|
||||||
send_tensor_list = list(object.values())
|
send_tensor_list = list(object.values())
|
||||||
|
|
||||||
recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
|
# Send and receive data
|
||||||
|
recv_buffer = _batch_send_recv_tensor(
|
||||||
|
send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
|
||||||
|
)
|
||||||
|
|
||||||
if recv_metadata is not None:
|
if metadata_recv is not None:
|
||||||
assert recv_buffer is not None
|
assert isinstance(metadata_recv, P2PMetadata)
|
||||||
if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
|
if metadata_recv.data_type == P2PDataType.Serialization:
|
||||||
return recv_buffer
|
return metadata_recv.content
|
||||||
elif recv_metadata.data_type == P2PDataType.dict:
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in zip(
|
|
||||||
[m.key for m in recv_metadata.content],
|
|
||||||
recv_buffer,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
|
assert recv_buffer is not None
|
||||||
|
if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
|
||||||
|
return recv_buffer
|
||||||
|
elif metadata_recv.data_type == P2PDataType.Dict:
|
||||||
|
return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
|
||||||
|
|
||||||
|
|
||||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
|
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
|
||||||
"""send anything to dst rank
|
"""send anything to dst rank
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -411,10 +447,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group)
|
_communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
|
||||||
|
|
||||||
|
|
||||||
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
|
def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
|
||||||
"""recv anything from src
|
"""recv anything from src
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -423,7 +459,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
|
||||||
Returns:
|
Returns:
|
||||||
Any: Object received from src.
|
Any: Object received from src.
|
||||||
"""
|
"""
|
||||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
|
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
|
||||||
|
|
||||||
|
|
||||||
def _p2p_comm(
|
def _p2p_comm(
|
||||||
|
@ -436,7 +472,7 @@ def _p2p_comm(
|
||||||
"""
|
"""
|
||||||
Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
|
Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
|
||||||
|
|
||||||
Agrs:
|
Args:
|
||||||
tensor_send_next (torch.Tensor): tensor to be sent to next stage
|
tensor_send_next (torch.Tensor): tensor to be sent to next stage
|
||||||
recv_prev (bool): whether to receive tensor from previous stage
|
recv_prev (bool): whether to receive tensor from previous stage
|
||||||
peer (int): rank of the peer
|
peer (int): rank of the peer
|
||||||
|
@ -467,7 +503,6 @@ def _p2p_comm(
|
||||||
group=group,
|
group=group,
|
||||||
)
|
)
|
||||||
ops.append(recv_prev_op)
|
ops.append(recv_prev_op)
|
||||||
|
|
||||||
if len(ops) > 0:
|
if len(ops) > 0:
|
||||||
reqs = dist.batch_isend_irecv(ops)
|
reqs = dist.batch_isend_irecv(ops)
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
|
@ -490,7 +525,6 @@ def _p2p_comm(
|
||||||
group=group,
|
group=group,
|
||||||
)
|
)
|
||||||
ops.append(send_next_op)
|
ops.append(send_next_op)
|
||||||
|
|
||||||
if tensor_recv_prev is not None:
|
if tensor_recv_prev is not None:
|
||||||
recv_prev_op = dist.P2POp(
|
recv_prev_op = dist.P2POp(
|
||||||
dist.irecv,
|
dist.irecv,
|
||||||
|
@ -510,7 +544,7 @@ class PipelineP2PCommunication:
|
||||||
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
||||||
self.stage_manager = stage_manager
|
self.stage_manager = stage_manager
|
||||||
|
|
||||||
def recv_forward(self, prev_rank: int = None) -> Any:
|
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -522,11 +556,13 @@ class PipelineP2PCommunication:
|
||||||
if prev_rank is None:
|
if prev_rank is None:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
|
input_tensor = _recv_object(
|
||||||
|
prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
|
||||||
|
)
|
||||||
|
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
def recv_backward(self, next_rank: int = None) -> Any:
|
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -539,12 +575,12 @@ class PipelineP2PCommunication:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
output_tensor_grad = _recv_object(
|
output_tensor_grad = _recv_object(
|
||||||
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
|
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_tensor_grad
|
return output_tensor_grad
|
||||||
|
|
||||||
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||||
"""Sends the input tensor to the next stage in pipeline.
|
"""Sends the input tensor to the next stage in pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -554,9 +590,15 @@ class PipelineP2PCommunication:
|
||||||
if next_rank is None:
|
if next_rank is None:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
_send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
|
_send_object(
|
||||||
|
output_object,
|
||||||
|
cur_rank,
|
||||||
|
next_rank,
|
||||||
|
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||||
|
send_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -566,9 +608,21 @@ class PipelineP2PCommunication:
|
||||||
if prev_rank is None:
|
if prev_rank is None:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
|
_send_object(
|
||||||
|
input_object,
|
||||||
|
cur_rank,
|
||||||
|
prev_rank,
|
||||||
|
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
|
||||||
|
send_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
|
def send_forward_recv_backward(
|
||||||
|
self,
|
||||||
|
input_object: Any,
|
||||||
|
next_rank: Optional[int] = None,
|
||||||
|
send_metadata: bool = True,
|
||||||
|
metadata_recv: Optional[P2PMetadata] = None,
|
||||||
|
) -> Any:
|
||||||
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -581,11 +635,22 @@ class PipelineP2PCommunication:
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
||||||
return _communicate(
|
return _communicate(
|
||||||
input_object, next_rank, next_rank,
|
input_object,
|
||||||
send_group=group, recv_group=group,
|
next_rank,
|
||||||
|
next_rank,
|
||||||
|
send_group=group,
|
||||||
|
recv_group=group,
|
||||||
|
send_metadata=send_metadata,
|
||||||
|
metadata_recv=metadata_recv,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
|
def send_backward_recv_forward(
|
||||||
|
self,
|
||||||
|
input_object: Any,
|
||||||
|
prev_rank: Optional[int] = None,
|
||||||
|
send_metadata: bool = True,
|
||||||
|
metadata_recv: Optional[P2PMetadata] = None,
|
||||||
|
) -> Any:
|
||||||
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -597,37 +662,22 @@ class PipelineP2PCommunication:
|
||||||
|
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
||||||
return _communicate(
|
|
||||||
input_object, prev_rank, prev_rank,
|
|
||||||
send_group=group, recv_group=group,
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
|
|
||||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_object (Any): Object to be sent.
|
|
||||||
prev_rank (int, optional): The rank of the sender of the tensor
|
|
||||||
next_rank (int, optional): The rank of the recipient of the tensor
|
|
||||||
"""
|
|
||||||
if prev_rank is None:
|
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
|
||||||
if next_rank is None:
|
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
|
||||||
|
|
||||||
cur_rank = self.stage_manager.get_rank()
|
|
||||||
recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
|
||||||
send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
|
||||||
return _communicate(
|
return _communicate(
|
||||||
input_object,
|
input_object,
|
||||||
send_dst=next_rank,
|
prev_rank,
|
||||||
recv_src=prev_rank,
|
prev_rank,
|
||||||
send_group=send_group,
|
send_group=group,
|
||||||
recv_group=recv_group,
|
recv_group=group,
|
||||||
|
send_metadata=send_metadata,
|
||||||
|
metadata_recv=metadata_recv,
|
||||||
)
|
)
|
||||||
|
|
||||||
def p2p_communicate(
|
def p2p_communicate(
|
||||||
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
|
self,
|
||||||
|
output_object: Any,
|
||||||
|
recv_pre: bool,
|
||||||
|
next_rank: Optional[int] = None,
|
||||||
|
comm_dtype: torch.dtype = torch.float16,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
|
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
|
||||||
|
@ -636,10 +686,14 @@ class PipelineP2PCommunication:
|
||||||
output_object (Any): Object to be sent.
|
output_object (Any): Object to be sent.
|
||||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||||
"""
|
"""
|
||||||
if peer is None:
|
if next_rank is None:
|
||||||
peer = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
cur_rank = self.stage_manager.get_rank()
|
cur_rank = self.stage_manager.get_rank()
|
||||||
recv_tensor = _p2p_comm(
|
recv_tensor = _p2p_comm(
|
||||||
output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
|
output_object,
|
||||||
|
recv_pre,
|
||||||
|
next_rank,
|
||||||
|
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||||
|
comm_dtype,
|
||||||
)
|
)
|
||||||
return recv_tensor
|
return recv_tensor
|
||||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.utils.device import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
assert (
|
assert (
|
||||||
num_microbatch is not None or microbatch_size is not None
|
num_microbatch is not None or microbatch_size is not None
|
||||||
), "Either num_microbatch or microbatch_size should be provided"
|
), "Either num_microbatch or microbatch_size should be provided"
|
||||||
|
|
||||||
self.comm = PipelineP2PCommunication(stage_manager)
|
self.comm = PipelineP2PCommunication(stage_manager)
|
||||||
self.num_microbatch = num_microbatch
|
self.num_microbatch = num_microbatch
|
||||||
self.microbatch_size = microbatch_size
|
self.microbatch_size = microbatch_size
|
||||||
|
@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
|
|
||||||
self.batch: Any
|
self.batch: Any
|
||||||
self.batch_size: int
|
self.batch_size: int
|
||||||
|
self.last_batch_size: Optional[int] = None
|
||||||
self.microbatch_offset: List[int]
|
self.microbatch_offset: List[int]
|
||||||
|
|
||||||
|
# P2PMeta cache
|
||||||
|
self.send_metadata_forward = True
|
||||||
|
self.send_metadata_backward = True
|
||||||
|
self.metadata_recv_forward = None
|
||||||
|
self.metadata_recv_backward = None
|
||||||
|
|
||||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||||
"""Load a batch from data iterator.
|
"""Load a batch from data iterator.
|
||||||
|
|
||||||
|
@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
batch = tree_map(partial(to_device, device=device), batch)
|
batch = tree_map(partial(to_device, device=device), batch)
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
self.batch_size = get_batch_size(batch)
|
self.batch_size = get_batch_size(batch)
|
||||||
|
if self.last_batch_size is None:
|
||||||
|
self.last_batch_size = self.batch_size
|
||||||
|
else:
|
||||||
|
assert self.forward_only or self.last_batch_size == self.batch_size
|
||||||
|
# TODO: support arbitrary batch size when forward_only=True
|
||||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||||
if self.num_microbatch is not None:
|
if self.num_microbatch is not None:
|
||||||
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
|
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
|
||||||
|
@ -106,12 +119,13 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
Any: The input tensor or input tensor list.
|
Any: The input tensor or input tensor list.
|
||||||
"""
|
"""
|
||||||
if self.stage_manager.is_first_stage(model_chunk_id):
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
input_tensor = None
|
if not self.stage_manager.is_first_stage():
|
||||||
else:
|
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
|
||||||
input_tensor = self.comm.recv_forward(prev_rank)
|
if self.metadata_recv_forward is None:
|
||||||
|
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||||
|
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
|
||||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||||
|
@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
Any: The input gradient tensor or gradient tensor list.
|
Any: The input gradient tensor or gradient tensor list.
|
||||||
"""
|
"""
|
||||||
if self.stage_manager.is_last_stage(model_chunk_id):
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
output_tensor_grad = None
|
if not self.stage_manager.is_last_stage():
|
||||||
else:
|
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
|
||||||
output_tensor_grad = self.comm.recv_backward(next_rank)
|
if self.metadata_recv_backward is None:
|
||||||
|
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||||
|
|
||||||
return output_tensor_grad
|
return output_tensor_grad
|
||||||
|
|
||||||
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
|
def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
|
||||||
"""Sends the input tensor to the next stage in pipeline.
|
"""Sends the input tensor to the next stage in pipeline.
|
||||||
For interleaved 1F1B.
|
For interleaved 1F1B.
|
||||||
|
|
||||||
|
@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
output_object (Any): Object to be sent.
|
output_object (Any): Object to be sent.
|
||||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||||
"""
|
"""
|
||||||
if not self.stage_manager.is_last_stage(model_chunk_id):
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
self.comm.send_forward(output_object, next_rank)
|
if not self.stage_manager.is_last_stage():
|
||||||
|
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
|
||||||
|
self.send_metadata_forward = False
|
||||||
|
|
||||||
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
|
def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
|
||||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||||
For interleaved 1F1B.
|
For interleaved 1F1B.
|
||||||
|
|
||||||
|
@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
input_object (Any): Object to be sent.
|
input_object (Any): Object to be sent.
|
||||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||||
"""
|
"""
|
||||||
if not self.stage_manager.is_first_stage(model_chunk_id):
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
self.comm.send_backward(input_object, prev_rank)
|
if not self.stage_manager.is_first_stage():
|
||||||
|
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
|
||||||
|
self.send_metadata_backward = False
|
||||||
|
|
||||||
|
def send_forward_recv_backward(
|
||||||
|
self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
|
||||||
|
) -> Any:
|
||||||
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
|
if not self.stage_manager.is_last_stage():
|
||||||
|
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||||
|
output_object,
|
||||||
|
next_rank,
|
||||||
|
send_metadata=self.send_metadata_forward,
|
||||||
|
metadata_recv=self.metadata_recv_backward,
|
||||||
|
)
|
||||||
|
self.send_metadata_forward = False
|
||||||
|
if self.metadata_recv_backward is None:
|
||||||
|
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||||
|
|
||||||
|
return output_tensor_grad
|
||||||
|
|
||||||
|
def send_backward_recv_forward(
|
||||||
|
self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
|
||||||
|
) -> Any:
|
||||||
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
|
if not self.stage_manager.is_first_stage():
|
||||||
|
input_tensor = self.comm.send_backward_recv_forward(
|
||||||
|
output_object,
|
||||||
|
prev_rank,
|
||||||
|
send_metadata=self.send_metadata_backward,
|
||||||
|
metadata_recv=self.metadata_recv_forward,
|
||||||
|
)
|
||||||
|
self.send_metadata_backward = False
|
||||||
|
if self.metadata_recv_forward is None:
|
||||||
|
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||||
|
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
def forward_step(
|
def forward_step(
|
||||||
self,
|
self,
|
||||||
|
@ -180,25 +233,24 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
# for the first stage, input_obj is None
|
# for the first stage, input_obj is None
|
||||||
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
|
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
|
||||||
|
|
||||||
self.stage_manager.model_chunk_id = model_chunk_id
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
if isinstance(model_chunk, ModuleList):
|
if isinstance(model_chunk, ModuleList):
|
||||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||||
else:
|
else:
|
||||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||||
internal_inputs = {} if input_obj is None else input_obj
|
internal_inputs = {} if input_obj is None else input_obj
|
||||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||||
self.stage_manager.model_chunk_id = None
|
|
||||||
|
|
||||||
if self.stage_manager.is_last_stage(model_chunk_id):
|
if self.stage_manager.is_last_stage():
|
||||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||||
if accum_loss is not None:
|
if accum_loss is not None:
|
||||||
accum_loss.add_(loss.detach())
|
accum_loss.add_(loss.detach())
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
outputs.append(tree_map(detach, output_obj))
|
outputs.append(tree_map(detach, output_obj))
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
return output_obj
|
return output_obj
|
||||||
|
|
||||||
def backward_step(
|
def backward_step(
|
||||||
self,
|
self,
|
||||||
|
@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dict with keys: 'loss' and 'outputs'.
|
dict: A dict with keys: 'loss' and 'outputs'.
|
||||||
"""
|
"""
|
||||||
# TODO: handle arbitrary batch size when forward_only == True
|
self.forward_only = not torch.is_grad_enabled()
|
||||||
forward_only = not torch.is_grad_enabled()
|
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
assert forward_only, "Optimizer should be passed when doing backward."
|
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||||
|
|
||||||
self.load_batch(data_iter)
|
self.load_batch(data_iter)
|
||||||
|
|
||||||
num_microbatch = self.num_microbatch * self.num_model_chunks
|
num_microbatch = self.num_microbatch * self.num_model_chunks
|
||||||
if forward_only:
|
if self.forward_only:
|
||||||
num_warmup_microbatch = num_microbatch
|
num_warmup_microbatch = num_microbatch
|
||||||
else:
|
else:
|
||||||
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
|
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
|
||||||
|
@ -288,43 +339,29 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
input_objs = None
|
input_objs = None
|
||||||
output_objs = None
|
output_objs = None
|
||||||
|
|
||||||
if not forward_only:
|
if not self.forward_only:
|
||||||
input_objs = [[] for _ in range(self.num_model_chunks)]
|
input_objs = [[] for _ in range(self.num_model_chunks)]
|
||||||
output_objs = [[] for _ in range(self.num_model_chunks)]
|
output_objs = [[] for _ in range(self.num_model_chunks)]
|
||||||
|
|
||||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
|
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
|
||||||
|
|
||||||
if return_loss and self.stage_manager.is_last_stage(-1):
|
accum_loss = None
|
||||||
|
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
accum_loss = torch.zeros(1, device=get_current_device())
|
accum_loss = torch.zeros(1, device=get_current_device())
|
||||||
else:
|
|
||||||
accum_loss = None
|
|
||||||
|
|
||||||
# for ranks except the first one, get into recv state
|
|
||||||
input_obj = self.recv_forward(0)
|
|
||||||
|
|
||||||
# Run warmup forward passes.
|
# Run warmup forward passes.
|
||||||
for i in range(num_warmup_microbatch):
|
for i in range(num_warmup_microbatch):
|
||||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||||
# recv first on first rank to avoid sending or receiving at the same time
|
input_obj = self.recv_forward(model_chunk_id)
|
||||||
if self.stage_manager.is_first_stage(-1):
|
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||||
input_obj = self.recv_forward(model_chunk_id)
|
if not self.forward_only:
|
||||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
input_objs[model_chunk_id].append(input_obj)
|
||||||
self.send_forward(model_chunk_id, output_obj)
|
output_objs[model_chunk_id].append(output_obj)
|
||||||
if not forward_only:
|
self.send_forward(model_chunk_id, output_obj)
|
||||||
input_objs[model_chunk_id].append(input_obj)
|
|
||||||
output_objs[model_chunk_id].append(output_obj)
|
|
||||||
else:
|
|
||||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
|
||||||
if not forward_only:
|
|
||||||
input_objs[model_chunk_id].append(input_obj)
|
|
||||||
output_objs[model_chunk_id].append(output_obj)
|
|
||||||
self.send_forward(model_chunk_id, output_obj)
|
|
||||||
|
|
||||||
if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
|
if num_microbatch_remaining > 0:
|
||||||
break
|
model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
|
||||||
|
input_obj = self.recv_forward(model_chunk_id)
|
||||||
model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
|
|
||||||
input_obj = self.recv_forward(model_chunk_id)
|
|
||||||
|
|
||||||
# Run 1F1B in steady state.
|
# Run 1F1B in steady state.
|
||||||
for i in range(num_microbatch_remaining):
|
for i in range(num_microbatch_remaining):
|
||||||
|
@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
last_iteration = i == num_microbatch_remaining - 1
|
last_iteration = i == num_microbatch_remaining - 1
|
||||||
|
|
||||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||||
if forward_only:
|
if self.forward_only:
|
||||||
self.send_forward(model_chunk_id, output_obj)
|
|
||||||
|
|
||||||
if not last_iteration:
|
if not last_iteration:
|
||||||
input_obj = self.recv_forward(model_chunk_id)
|
input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
|
||||||
|
else:
|
||||||
|
self.send_forward(model_chunk_id, output_obj)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.send_forward(model_chunk_id, output_obj)
|
self.send_forward(model_chunk_id, output_obj)
|
||||||
|
@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
|
|
||||||
# backward
|
# backward
|
||||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||||
|
self.send_backward(model_chunk_id, input_obj_grad)
|
||||||
|
|
||||||
if last_iteration:
|
if not last_iteration:
|
||||||
input_obj = None
|
|
||||||
else:
|
|
||||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
|
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
|
||||||
input_obj = self.recv_forward(model_chunk_id)
|
input_obj = self.recv_forward(model_chunk_id)
|
||||||
|
|
||||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
|
||||||
self.send_backward(model_chunk_id, input_obj_grad)
|
|
||||||
|
|
||||||
# Run cooldown backward passes.
|
# Run cooldown backward passes.
|
||||||
if not forward_only:
|
if not self.forward_only:
|
||||||
for i in range(num_microbatch_remaining, num_microbatch):
|
for i in range(num_microbatch_remaining, num_microbatch):
|
||||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||||
input_obj = input_objs[model_chunk_id].pop(0)
|
input_obj = input_objs[model_chunk_id].pop(0)
|
||||||
|
@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||||
self.send_backward(model_chunk_id, input_obj_grad)
|
self.send_backward(model_chunk_id, input_obj_grad)
|
||||||
|
|
||||||
if not forward_only:
|
if not self.forward_only:
|
||||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||||
|
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
|
|
|
@ -7,7 +7,7 @@ from torch.nn import Module
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.utils.device import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
assert (
|
assert (
|
||||||
num_microbatches is not None or microbatch_size is not None
|
num_microbatches is not None or microbatch_size is not None
|
||||||
), "Either num_microbatches or microbatch_size should be provided"
|
), "Either num_microbatches or microbatch_size should be provided"
|
||||||
|
|
||||||
self.comm = PipelineP2PCommunication(stage_manager)
|
self.comm = PipelineP2PCommunication(stage_manager)
|
||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.microbatch_size = microbatch_size
|
self.microbatch_size = microbatch_size
|
||||||
self.batch: Optional[Any] = None
|
self.batch: Optional[Any] = None
|
||||||
self.batch_size: Optional[int] = None
|
self.batch_size: Optional[int] = None
|
||||||
|
self.last_batch_size: Optional[int] = None
|
||||||
self.microbatch_offset: Optional[int] = None
|
self.microbatch_offset: Optional[int] = None
|
||||||
self._use_microbatch_size = num_microbatches is None
|
self._use_microbatch_size = num_microbatches is None
|
||||||
|
|
||||||
|
# P2PMeta cache
|
||||||
|
self.send_metadata_forward = True
|
||||||
|
self.send_metadata_backward = True
|
||||||
|
self.metadata_recv_forward = None
|
||||||
|
self.metadata_recv_backward = None
|
||||||
|
|
||||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||||
"""Load a batch from data iterator.
|
"""Load a batch from data iterator.
|
||||||
|
|
||||||
|
@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
batch = next(data_iter)
|
batch = next(data_iter)
|
||||||
if device is not None:
|
if device is not None:
|
||||||
batch = tree_map(partial(to_device, device=device), batch)
|
batch = tree_map(partial(to_device, device=device), batch)
|
||||||
|
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
self.batch_size = get_batch_size(batch)
|
self.batch_size = get_batch_size(batch)
|
||||||
|
if self.last_batch_size is None:
|
||||||
|
self.last_batch_size = self.batch_size
|
||||||
|
else:
|
||||||
|
assert self.forward_only or self.last_batch_size == self.batch_size
|
||||||
|
# TODO: support arbitrary batch size when forward_only=True
|
||||||
self.microbatch_offset = 0
|
self.microbatch_offset = 0
|
||||||
if not self._use_microbatch_size:
|
if not self._use_microbatch_size:
|
||||||
assert (
|
assert (
|
||||||
|
@ -92,12 +106,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
Any: The input tensor or input tensor list.
|
Any: The input tensor or input tensor list.
|
||||||
"""
|
"""
|
||||||
if self.stage_manager.is_first_stage():
|
if not self.stage_manager.is_first_stage():
|
||||||
input_tensor = None
|
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
|
||||||
else:
|
if self.metadata_recv_forward is None:
|
||||||
input_tensor = self.comm.recv_forward(prev_rank)
|
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||||
|
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
def recv_backward(self, next_rank: int = None) -> Any:
|
def recv_backward(self, next_rank: int = None) -> Any:
|
||||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||||
|
@ -109,12 +123,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
Any: The input gradient tensor or gradient tensor list.
|
Any: The input gradient tensor or gradient tensor list.
|
||||||
"""
|
"""
|
||||||
if self.stage_manager.is_last_stage():
|
if not self.stage_manager.is_last_stage():
|
||||||
output_tensor_grad = None
|
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
|
||||||
else:
|
if self.metadata_recv_backward is None:
|
||||||
output_tensor_grad = self.comm.recv_backward(next_rank)
|
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||||
|
|
||||||
return output_tensor_grad
|
return output_tensor_grad
|
||||||
|
|
||||||
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
||||||
"""Sends the input tensor to the next stage in pipeline.
|
"""Sends the input tensor to the next stage in pipeline.
|
||||||
|
@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||||
"""
|
"""
|
||||||
if not self.stage_manager.is_last_stage():
|
if not self.stage_manager.is_last_stage():
|
||||||
self.comm.send_forward(output_object, next_rank)
|
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
|
||||||
|
self.send_metadata_forward = False
|
||||||
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
|
|
||||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
|
||||||
For 1F1B.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_object (Any): Object to be sent.
|
|
||||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
|
||||||
"""
|
|
||||||
if not self.stage_manager.is_last_stage():
|
|
||||||
return self.comm.send_forward_recv_backward(output_object, next_rank)
|
|
||||||
|
|
||||||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
||||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||||
|
@ -147,7 +151,29 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||||
"""
|
"""
|
||||||
if not self.stage_manager.is_first_stage():
|
if not self.stage_manager.is_first_stage():
|
||||||
self.comm.send_backward(input_object, prev_rank)
|
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
|
||||||
|
self.send_metadata_backward = False
|
||||||
|
|
||||||
|
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
|
||||||
|
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||||
|
For 1F1B.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_object (Any): Object to be sent.
|
||||||
|
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||||
|
"""
|
||||||
|
if not self.stage_manager.is_last_stage():
|
||||||
|
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||||
|
output_object,
|
||||||
|
next_rank,
|
||||||
|
send_metadata=self.send_metadata_forward,
|
||||||
|
metadata_recv=self.metadata_recv_backward,
|
||||||
|
)
|
||||||
|
self.send_metadata_forward = False
|
||||||
|
if self.metadata_recv_backward is None:
|
||||||
|
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||||
|
|
||||||
|
return output_tensor_grad
|
||||||
|
|
||||||
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
|
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
|
||||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
||||||
|
@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
prev_rank (int, optional): The rank of the recipient of the tensor.
|
prev_rank (int, optional): The rank of the recipient of the tensor.
|
||||||
"""
|
"""
|
||||||
if not self.stage_manager.is_first_stage():
|
if not self.stage_manager.is_first_stage():
|
||||||
return self.comm.send_backward_recv_forward(output_object, prev_rank)
|
input_tensor = self.comm.send_backward_recv_forward(
|
||||||
|
output_object,
|
||||||
|
prev_rank,
|
||||||
|
send_metadata=self.send_metadata_backward,
|
||||||
|
metadata_recv=self.metadata_recv_forward,
|
||||||
|
)
|
||||||
|
self.send_metadata_backward = False
|
||||||
|
if self.metadata_recv_forward is None:
|
||||||
|
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||||
|
|
||||||
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
|
return input_tensor
|
||||||
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
|
|
||||||
For 1F1B.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_object (Any): Object to be sent.
|
|
||||||
prev_rank (int, optional): The previous rank of the recipient of the tensor.
|
|
||||||
next_rank (int, optional): The next rank of the recipient of the tensor.
|
|
||||||
"""
|
|
||||||
if self.stage_manager.is_first_stage():
|
|
||||||
return self.comm.send_forward(input_object, next_rank)
|
|
||||||
elif self.stage_manager.is_last_stage():
|
|
||||||
return self.comm.recv_forward(prev_rank)
|
|
||||||
else:
|
|
||||||
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
|
|
||||||
|
|
||||||
def forward_step(
|
def forward_step(
|
||||||
self,
|
self,
|
||||||
|
@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dict with keys: 'loss' and 'outputs'.
|
dict: A dict with keys: 'loss' and 'outputs'.
|
||||||
"""
|
"""
|
||||||
forward_only = not torch.is_grad_enabled()
|
|
||||||
|
self.forward_only = not torch.is_grad_enabled()
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
assert forward_only, "Optimizer should be passed when doing backward."
|
assert self.forward_only, "Optimizer should be passed when doing backward."
|
||||||
|
|
||||||
self.load_batch(data_iter)
|
self.load_batch(data_iter)
|
||||||
|
|
||||||
|
@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
input_objs = None
|
input_objs = None
|
||||||
output_objs = None
|
output_objs = None
|
||||||
|
|
||||||
if not forward_only:
|
if not self.forward_only:
|
||||||
input_objs = []
|
input_objs = []
|
||||||
output_objs = []
|
output_objs = []
|
||||||
|
|
||||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
accum_loss = None
|
||||||
if return_loss and self.stage_manager.is_last_stage():
|
if return_loss and self.stage_manager.is_last_stage():
|
||||||
accum_loss = torch.zeros(1, device=get_current_device())
|
accum_loss = torch.zeros(1, device=get_current_device())
|
||||||
else:
|
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||||
accum_loss = None
|
|
||||||
|
|
||||||
# Run warmup forward passes.
|
# Run warmup forward passes.
|
||||||
for i in range(num_warmup_microbatches):
|
for i in range(num_warmup_microbatches):
|
||||||
input_obj = self.recv_forward()
|
input_obj = self.recv_forward()
|
||||||
|
|
||||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||||
|
|
||||||
self.send_forward(output_obj)
|
self.send_forward(output_obj)
|
||||||
|
|
||||||
if not forward_only:
|
if not self.forward_only:
|
||||||
input_objs.append(input_obj)
|
input_objs.append(input_obj)
|
||||||
output_objs.append(output_obj)
|
output_objs.append(output_obj)
|
||||||
|
|
||||||
|
@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
last_iteration = i == (num_microbatches_remaining - 1)
|
last_iteration = i == (num_microbatches_remaining - 1)
|
||||||
|
|
||||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||||
if forward_only:
|
|
||||||
|
if self.forward_only:
|
||||||
self.send_forward(output_obj)
|
self.send_forward(output_obj)
|
||||||
|
|
||||||
if not last_iteration:
|
if not last_iteration:
|
||||||
input_obj = self.recv_forward()
|
input_obj = self.recv_forward()
|
||||||
else:
|
|
||||||
# TODO adjust here
|
|
||||||
self.send_forward(output_obj)
|
|
||||||
output_obj_grad = self.recv_backward()
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
output_obj_grad = self.send_forward_recv_backward(output_obj)
|
||||||
# Add input_obj and output_obj to end of list.
|
# Add input_obj and output_obj to end of list.
|
||||||
input_objs.append(input_obj)
|
input_objs.append(input_obj)
|
||||||
output_objs.append(output_obj)
|
output_objs.append(output_obj)
|
||||||
|
@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||||
|
|
||||||
if last_iteration:
|
if last_iteration:
|
||||||
input_obj = None
|
self.send_backward(input_obj_grad)
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward()
|
input_obj = self.send_backward_recv_forward(input_obj_grad)
|
||||||
self.send_backward(input_obj_grad)
|
|
||||||
|
|
||||||
# Run cooldown backward passes.
|
# Run cooldown backward passes.
|
||||||
if not forward_only:
|
if not self.forward_only:
|
||||||
for i in range(num_warmup_microbatches):
|
for i in range(num_warmup_microbatches):
|
||||||
input_obj = input_objs.pop(0)
|
input_obj = input_objs.pop(0)
|
||||||
output_obj = output_objs.pop(0)
|
output_obj = output_objs.pop(0)
|
||||||
|
@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||||
self.send_backward(input_obj_grad)
|
self.send_backward(input_obj_grad)
|
||||||
|
|
||||||
|
if not self.forward_only:
|
||||||
|
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||||
|
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
if isinstance(model, ModelWrapper):
|
if isinstance(model, ModelWrapper):
|
||||||
model = model.unwrap()
|
model = model.unwrap()
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import contextlib
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -68,45 +69,39 @@ class PipelineStageManager:
|
||||||
# for shardformer, hold model chunk id
|
# for shardformer, hold model chunk id
|
||||||
self.model_chunk_id: Optional[int] = None
|
self.model_chunk_id: Optional[int] = None
|
||||||
|
|
||||||
def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
|
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
|
||||||
"""Is the current stage the first stage.
|
"""Is the current stage the first stage.
|
||||||
|
|
||||||
NOTE:
|
NOTE:
|
||||||
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
|
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
|
||||||
2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device()
|
2. invoke is_first_stage() with ignore_chunk=True is equivalent to invoke is_first_device()
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the current stage is the first stage.
|
bool: Whether the current stage is the first stage.
|
||||||
"""
|
"""
|
||||||
if self.is_interleave and model_chunk_id is None:
|
assert isinstance(ignore_chunk, bool)
|
||||||
model_chunk_id = self.model_chunk_id
|
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
|
||||||
assert self.is_interleave ^ (
|
if not self.is_interleave or ignore_chunk:
|
||||||
model_chunk_id is None
|
|
||||||
), "model_chunk_id must be specified when using interleaved pipeline"
|
|
||||||
if not self.is_interleave or model_chunk_id < 0:
|
|
||||||
return self.stage == 0
|
return self.stage == 0
|
||||||
else:
|
else:
|
||||||
return self.stage == 0 and model_chunk_id == 0
|
return self.stage == 0 and self.model_chunk_id == 0
|
||||||
|
|
||||||
def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
|
def is_last_stage(self, ignore_chunk: bool = False) -> bool:
|
||||||
"""Is the current stage the last stage.
|
"""Is the current stage the last stage.
|
||||||
|
|
||||||
NOTE:
|
NOTE:
|
||||||
1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
|
1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
|
||||||
2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device()
|
2. invoke is_last_stage() with ignore_chunk=True is equivalent to invoke is_last_device()
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the current stage is the last stage.
|
bool: Whether the current stage is the last stage.
|
||||||
"""
|
"""
|
||||||
if self.is_interleave and model_chunk_id is None:
|
assert isinstance(ignore_chunk, bool)
|
||||||
model_chunk_id = self.model_chunk_id
|
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
|
||||||
assert self.is_interleave ^ (
|
if not self.is_interleave or ignore_chunk:
|
||||||
model_chunk_id is None
|
|
||||||
), "model_chunk_id must be specified when using interleaved pipeline"
|
|
||||||
if not self.is_interleave or model_chunk_id < 0:
|
|
||||||
return self.stage == self.num_stages - 1
|
return self.stage == self.num_stages - 1
|
||||||
else:
|
else:
|
||||||
return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1
|
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_stages(self) -> int:
|
def num_stages(self) -> int:
|
||||||
|
@ -174,3 +169,10 @@ class PipelineStageManager:
|
||||||
ProcessGroup: Process group of the given stages.
|
ProcessGroup: Process group of the given stages.
|
||||||
"""
|
"""
|
||||||
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
|
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def switch_model_chunk_id(self, model_chunk_id: int):
|
||||||
|
old_model_chunk_id = self.model_chunk_id
|
||||||
|
self.model_chunk_id = model_chunk_id
|
||||||
|
yield
|
||||||
|
self.model_chunk_id = old_model_chunk_id
|
||||||
|
|
|
@ -309,11 +309,11 @@ class BertPolicy(Policy):
|
||||||
num_model_chunks=stage_manager.num_model_chunks,
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
num_stages=stage_manager.num_stages,
|
num_stages=stage_manager.num_stages,
|
||||||
)
|
)
|
||||||
if stage_manager.is_first_stage(-1):
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
held_layers.append(module.embeddings)
|
held_layers.append(module.embeddings)
|
||||||
for start_idx, end_idx in stage_indices:
|
for start_idx, end_idx in stage_indices:
|
||||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||||
if stage_manager.is_last_stage(-1):
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(module.pooler)
|
held_layers.append(module.pooler)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
|
||||||
"""Get pipeline layers for current stage"""
|
"""Get pipeline layers for current stage"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.cls)
|
held_layers.append(self.model.cls)
|
||||||
|
|
||||||
return held_layers
|
return held_layers
|
||||||
|
@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.cls)
|
held_layers.append(self.model.cls)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.cls)
|
held_layers.append(self.model.cls)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.dropout)
|
held_layers.append(self.model.dropout)
|
||||||
held_layers.append(self.model.classifier)
|
held_layers.append(self.model.classifier)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.dropout)
|
held_layers.append(self.model.dropout)
|
||||||
held_layers.append(self.model.classifier)
|
held_layers.append(self.model.classifier)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.cls)
|
held_layers.append(self.model.cls)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.dropout)
|
held_layers.append(self.model.dropout)
|
||||||
held_layers.append(self.model.classifier)
|
held_layers.append(self.model.classifier)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
|
||||||
"""
|
"""
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.qa_outputs)
|
held_layers.append(self.model.qa_outputs)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,11 @@ from torch.nn import Module
|
||||||
|
|
||||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
|
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
|
||||||
|
|
||||||
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy
|
from ..modeling.llama import (
|
||||||
|
LlamaPipelineForwards,
|
||||||
|
get_llama_flash_attention_forward,
|
||||||
|
get_lm_forward_with_dist_cross_entropy,
|
||||||
|
)
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
|
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
|
||||||
|
@ -140,21 +144,42 @@ class LlamaPolicy(Policy):
|
||||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||||
to customized forward method, and add this changing to policy."""
|
to customized forward method, and add this changing to policy."""
|
||||||
if self.pipeline_stage_manager:
|
if self.pipeline_stage_manager is None:
|
||||||
stage_manager = self.pipeline_stage_manager
|
return
|
||||||
if self.model.__class__.__name__ == "LlamaModel":
|
|
||||||
module = self.model
|
|
||||||
else:
|
|
||||||
module = self.model.model
|
|
||||||
|
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
if self.model.__class__.__name__ == "LlamaModel":
|
||||||
|
module = self.model
|
||||||
|
else:
|
||||||
|
module = self.model.model
|
||||||
|
|
||||||
|
if stage_manager.is_interleave:
|
||||||
|
layers_per_stage = self.distribute_layers(
|
||||||
|
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||||
|
)
|
||||||
|
stage_manager.stage_indices = Policy.get_stage_index(
|
||||||
|
layers_per_stage,
|
||||||
|
stage_manager.stage,
|
||||||
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)}
|
method_replacement = {
|
||||||
|
"forward": partial(
|
||||||
|
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||||
|
)
|
||||||
|
}
|
||||||
self.append_or_create_method_replacement(
|
self.append_or_create_method_replacement(
|
||||||
description=method_replacement, policy=policy, target_key=model_cls
|
description=method_replacement, policy=policy, target_key=model_cls
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||||
|
|
||||||
def get_held_layers(self) -> List[Module]:
|
def get_held_layers(self) -> List[Module]:
|
||||||
"""Get pipeline layers for current stage."""
|
"""Get pipeline layers for current stage."""
|
||||||
|
@ -167,13 +192,32 @@ class LlamaPolicy(Policy):
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
held_layers = []
|
held_layers = []
|
||||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
if stage_manager.is_interleave:
|
||||||
if stage_manager.is_first_stage():
|
assert stage_manager.num_model_chunks is not None
|
||||||
held_layers.append(module.embed_tokens)
|
layers_per_stage = self.distribute_layers(
|
||||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||||
held_layers.extend(module.layers[start_idx:end_idx])
|
)
|
||||||
if stage_manager.is_last_stage():
|
stage_indices = Policy.get_stage_index(
|
||||||
held_layers.append(module.norm)
|
layers_per_stage,
|
||||||
|
stage_manager.stage,
|
||||||
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
)
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
for start_idx, end_idx in stage_indices:
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
|
else:
|
||||||
|
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
new_item = {
|
new_item = {
|
||||||
LlamaForCausalLM: ModulePolicyDescription(
|
LlamaForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
|
||||||
suffix="lm_head", target_module=Linear1D_Col
|
|
||||||
)
|
|
||||||
],
|
],
|
||||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
|
@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
"""Get pipeline layers for current stage."""
|
"""Get pipeline layers for current stage."""
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.lm_head)
|
held_layers.append(self.model.lm_head)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||||
"""Get pipeline layers for current stage."""
|
"""Get pipeline layers for current stage."""
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.score)
|
held_layers.append(self.model.score)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
|
|
@ -57,9 +57,7 @@ def evaluate_model(
|
||||||
|
|
||||||
def evaluate_subset(dataloader: DataLoader):
|
def evaluate_subset(dataloader: DataLoader):
|
||||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
|
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
|
||||||
None if not booster.plugin.stage_manager.is_interleave else -1
|
|
||||||
)
|
|
||||||
|
|
||||||
accum_loss = torch.zeros(1, device=get_current_device())
|
accum_loss = torch.zeros(1, device=get_current_device())
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
|
@ -136,9 +134,7 @@ def train_epoch(
|
||||||
coordinator: DistCoordinator,
|
coordinator: DistCoordinator,
|
||||||
):
|
):
|
||||||
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||||
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
|
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
|
||||||
None if not booster.plugin.stage_manager.is_interleave else -1
|
|
||||||
)
|
|
||||||
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
|
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
|
||||||
total_step = len(train_dataloader)
|
total_step = len(train_dataloader)
|
||||||
|
|
||||||
|
|
|
@ -133,7 +133,9 @@ def main():
|
||||||
plugin = HybridParallelPlugin(
|
plugin = HybridParallelPlugin(
|
||||||
tp_size=args.tp,
|
tp_size=args.tp,
|
||||||
pp_size=args.pp,
|
pp_size=args.pp,
|
||||||
|
pp_style="interleaved",
|
||||||
zero_stage=args.zero,
|
zero_stage=args.zero,
|
||||||
|
num_model_chunks=2,
|
||||||
enable_fused_normalization=torch.cuda.is_available(),
|
enable_fused_normalization=torch.cuda.is_available(),
|
||||||
num_microbatches=args.mbs,
|
num_microbatches=args.mbs,
|
||||||
precision="bf16",
|
precision="bf16",
|
||||||
|
|
|
@ -1,47 +1,80 @@
|
||||||
|
import warnings
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
WORLD_SIZE = 2
|
||||||
|
|
||||||
|
|
||||||
def check_p2p_communication():
|
def check_p2p_communication():
|
||||||
pg_mesh = ProcessGroupMesh(2)
|
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
|
||||||
stage_manager = PipelineStageManager(pg_mesh, 0)
|
stage_manager = PipelineStageManager(pg_mesh, 0)
|
||||||
p2p = PipelineP2PCommunication(stage_manager)
|
p2p = PipelineP2PCommunication(stage_manager)
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
||||||
tensor = torch.ones(1, device=get_current_device())
|
tensor = torch.ones(1, device=get_current_device())
|
||||||
|
data = [
|
||||||
|
"tensor",
|
||||||
|
tensor,
|
||||||
|
[tensor],
|
||||||
|
{"tensor": tensor},
|
||||||
|
]
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
p2p.send_forward(tensor)
|
for obj in data:
|
||||||
p2p.send_forward([tensor])
|
p2p.send_forward(obj)
|
||||||
p2p.send_forward({"tensor": tensor})
|
for i in range(len(data)):
|
||||||
else:
|
recv_obj = p2p.send_forward_recv_backward(data[i])
|
||||||
obj = p2p.recv_forward()
|
assert recv_obj == data[-(i + 1)]
|
||||||
assert torch.equal(obj, tensor)
|
elif rank == 1:
|
||||||
obj = p2p.recv_forward()
|
for obj in data:
|
||||||
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
|
recv_obj = p2p.recv_forward()
|
||||||
obj = p2p.recv_forward()
|
assert recv_obj == obj
|
||||||
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
|
for i in range(len(data)):
|
||||||
|
p2p.send_backward(data[-(i + 1)])
|
||||||
|
recv_obj = p2p.recv_forward()
|
||||||
|
assert recv_obj == data[i]
|
||||||
|
|
||||||
if rank == 1:
|
if rank == 1:
|
||||||
p2p.send_backward(tensor)
|
for obj in data:
|
||||||
p2p.send_backward([tensor])
|
p2p.send_backward(obj)
|
||||||
p2p.send_backward({"tensor": tensor})
|
for i in range(len(data)):
|
||||||
else:
|
recv_obj = p2p.send_backward_recv_forward(data[i])
|
||||||
obj = p2p.recv_backward()
|
assert recv_obj == data[-(i + 1)]
|
||||||
assert torch.equal(obj, tensor)
|
elif rank == 0:
|
||||||
obj = p2p.recv_backward()
|
for obj in data:
|
||||||
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
|
recv_obj = p2p.recv_backward()
|
||||||
obj = p2p.recv_backward()
|
assert recv_obj == obj
|
||||||
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
|
for i in range(len(data)):
|
||||||
|
recv_obj = p2p.recv_backward()
|
||||||
|
p2p.send_forward(data[-(i + 1)])
|
||||||
|
assert recv_obj == data[i]
|
||||||
|
|
||||||
|
warnings.filterwarnings("error")
|
||||||
|
tensor_metadata = TensorMetadata(
|
||||||
|
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
|
||||||
|
)
|
||||||
|
comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
|
||||||
|
if rank == 0:
|
||||||
|
recv_obj = p2p.send_forward_recv_backward(
|
||||||
|
tensor,
|
||||||
|
send_metadata=False,
|
||||||
|
metadata_recv=comm_metadata,
|
||||||
|
)
|
||||||
|
assert recv_obj == tensor
|
||||||
|
elif rank == 1:
|
||||||
|
recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
|
||||||
|
assert recv_obj == tensor
|
||||||
|
p2p.send_backward(tensor, send_metadata=False)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
|
@ -52,7 +85,7 @@ def run_dist(rank, world_size, port):
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_pipeline_p2p():
|
def test_pipeline_p2p():
|
||||||
spawn(run_dist, 2)
|
spawn(run_dist, WORLD_SIZE)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -37,12 +37,13 @@ def pp_linear_fwd(
|
||||||
stage_mgr: PipelineStageManager = None,
|
stage_mgr: PipelineStageManager = None,
|
||||||
model_chunk_id: int = None,
|
model_chunk_id: int = None,
|
||||||
):
|
):
|
||||||
if stage_mgr.is_first_stage(model_chunk_id):
|
with stage_mgr.switch_model_chunk_id(model_chunk_id):
|
||||||
return {"input_obj": forward(data)}
|
if stage_mgr.is_first_stage():
|
||||||
elif stage_mgr.is_last_stage(model_chunk_id):
|
return {"input_obj": forward(data)}
|
||||||
return forward(input_obj)
|
elif stage_mgr.is_last_stage():
|
||||||
else:
|
return forward(input_obj)
|
||||||
return {"input_obj": forward(input_obj)}
|
else:
|
||||||
|
return {"input_obj": forward(input_obj)}
|
||||||
|
|
||||||
|
|
||||||
def run_pp(
|
def run_pp(
|
||||||
|
@ -107,7 +108,7 @@ def run_pp(
|
||||||
)
|
)
|
||||||
|
|
||||||
# check loss
|
# check loss
|
||||||
if stage_manager.is_last_stage(-1):
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||||
|
|
||||||
# check gradients
|
# check gradients
|
||||||
|
@ -119,6 +120,7 @@ def run_pp(
|
||||||
# step
|
# step
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
pp_optimizer.step()
|
pp_optimizer.step()
|
||||||
|
pp_optimizer.zero_grad()
|
||||||
|
|
||||||
# check updated param
|
# check updated param
|
||||||
for i in range(num_model_chunk):
|
for i in range(num_model_chunk):
|
||||||
|
@ -126,6 +128,24 @@ def run_pp(
|
||||||
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
|
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
|
||||||
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
|
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
|
||||||
|
|
||||||
|
# forward only
|
||||||
|
with torch.no_grad():
|
||||||
|
torch_output = torch_model(input_list[0])
|
||||||
|
torch_loss = criterion(torch_output)
|
||||||
|
|
||||||
|
pp_ret = schedule.forward_backward_step(
|
||||||
|
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||||
|
)
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||||
|
|
||||||
|
for layer in sharded_model:
|
||||||
|
if layer.weight.grad is None:
|
||||||
|
assert layer.weight.grad is None and layer.bias.grad is None
|
||||||
|
else:
|
||||||
|
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
|
||||||
|
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("num_microbatch", [4, 12])
|
@pytest.mark.parametrize("num_microbatch", [4, 12])
|
||||||
|
|
|
@ -4,6 +4,7 @@ from types import MethodType
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
|
|
||||||
|
DIM = 8
|
||||||
|
NUM_LAYER = 8
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
class MlpModel(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MlpModel, self).__init__()
|
super().__init__()
|
||||||
self.linear1 = nn.Linear(4, 8)
|
self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])
|
||||||
self.linear2 = nn.Linear(8, 4)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.linear1(x)
|
for layer in self.layers:
|
||||||
x = self.linear2(x)
|
x = layer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def pp_linear_fwd(
|
def pp_linear_fwd(
|
||||||
forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
|
forward,
|
||||||
|
data: torch.Tensor = None,
|
||||||
|
input_obj: torch.Tensor = None,
|
||||||
|
stage_mgr: PipelineStageManager = None,
|
||||||
):
|
):
|
||||||
if stage_mgr.is_first_stage():
|
if stage_mgr.is_first_stage():
|
||||||
return {"input_obj": forward(data)}
|
return {"input_obj": forward(data)}
|
||||||
|
@ -38,34 +44,45 @@ def pp_linear_fwd(
|
||||||
return {"input_obj": forward(input_obj)}
|
return {"input_obj": forward(input_obj)}
|
||||||
|
|
||||||
|
|
||||||
def examine_pp():
|
def examine_pp(num_microbatch: int, batch_size: int):
|
||||||
"""
|
"""
|
||||||
This test is to examine the correctness of 1F1B, compared with torch.
|
This test is to examine the correctness of 1F1B, compared with torch.
|
||||||
Be aware it contains some hardcodes.
|
Be aware it contains some hardcodes.
|
||||||
"""
|
"""
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
local_rank = torch.distributed.get_rank()
|
dist.get_rank()
|
||||||
seed_all(1453)
|
seed_all(1453)
|
||||||
|
|
||||||
NUM_MICRO_BATCHS = 4
|
|
||||||
BATCH_SIZE = 4
|
|
||||||
|
|
||||||
# create models
|
# create models
|
||||||
torch_model = MlpModel().cuda()
|
torch_model = MlpModel().cuda()
|
||||||
|
|
||||||
pp_model = copy.deepcopy(torch_model).cuda()
|
pp_model = copy.deepcopy(torch_model).cuda()
|
||||||
|
|
||||||
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
|
pg_mesh = ProcessGroupMesh(world_size)
|
||||||
pg_mesh = ProcessGroupMesh(1, world_size, 1)
|
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
|
||||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)
|
||||||
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
|
|
||||||
|
|
||||||
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
|
rank = dist.get_rank()
|
||||||
if idx % (world_size) == local_rank:
|
sharded_model = torch.nn.ModuleList()
|
||||||
sharded_model = sub_model.cuda()
|
num_local_layer = NUM_LAYER // world_size
|
||||||
|
for idx, sub_model in enumerate(pp_model.layers):
|
||||||
|
if idx // num_local_layer == rank:
|
||||||
|
sharded_model.append(sub_model.cuda())
|
||||||
|
assert len(sharded_model) == num_local_layer
|
||||||
|
|
||||||
sharded_model._forward = sharded_model.forward
|
def custom_fwd(self, x):
|
||||||
sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
|
for layer in self._modules.values():
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
sharded_model._forward = MethodType(custom_fwd, sharded_model)
|
||||||
|
sharded_model.forward = MethodType(
|
||||||
|
partial(
|
||||||
|
pp_linear_fwd,
|
||||||
|
stage_mgr=stage_manager,
|
||||||
|
),
|
||||||
|
sharded_model._forward,
|
||||||
|
)
|
||||||
|
|
||||||
# create optimizer
|
# create optimizer
|
||||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||||
|
@ -73,19 +90,15 @@ def examine_pp():
|
||||||
|
|
||||||
# create
|
# create
|
||||||
seed_all(1453)
|
seed_all(1453)
|
||||||
if stage_manager.is_first_stage():
|
input_list = [torch.rand(batch_size, DIM).cuda()]
|
||||||
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
|
dist.all_reduce(input_list[0])
|
||||||
else:
|
|
||||||
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
|
|
||||||
torch.distributed.all_reduce(input_list[0])
|
|
||||||
|
|
||||||
criterion = lambda x, y: torch.mean(x)
|
criterion = lambda x, *arg, **kwargs: (x * x).mean()
|
||||||
|
|
||||||
# forward and backward
|
# forward and backward
|
||||||
torch_output = torch_model(input_list[0])
|
torch_output = torch_model(input_list[0])
|
||||||
torch_loss = criterion(torch_output, _)
|
torch_loss = criterion(torch_output)
|
||||||
torch_loss.backward()
|
torch_loss.backward()
|
||||||
|
|
||||||
pp_ret = schedule.forward_backward_step(
|
pp_ret = schedule.forward_backward_step(
|
||||||
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||||
)
|
)
|
||||||
|
@ -95,34 +108,66 @@ def examine_pp():
|
||||||
assert torch.allclose(torch_loss, pp_ret["loss"])
|
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||||
|
|
||||||
# check gradients
|
# check gradients
|
||||||
torch_grad = []
|
for i in range(len(sharded_model)):
|
||||||
for torch_p in torch_model.parameters():
|
idx = rank * num_local_layer + i
|
||||||
torch_grad.append(torch_p.grad.data)
|
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
|
||||||
for idx, pp_p in enumerate(sharded_model.parameters()):
|
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
|
||||||
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
|
|
||||||
|
|
||||||
# step
|
# step
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
pp_optimizer.step()
|
pp_optimizer.step()
|
||||||
|
pp_optimizer.zero_grad()
|
||||||
|
|
||||||
# check updated param
|
# check updated param
|
||||||
torch_param = []
|
for i in range(len(sharded_model)):
|
||||||
for torch_p in torch_model.parameters():
|
idx = rank * num_local_layer + i
|
||||||
torch_param.append(torch_p.data)
|
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
|
||||||
for idx, pp_p in enumerate(sharded_model.parameters()):
|
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
|
||||||
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
|
|
||||||
|
# forward only
|
||||||
|
with torch.no_grad():
|
||||||
|
torch_output = torch_model(input_list[0])
|
||||||
|
torch_loss = criterion(torch_output)
|
||||||
|
|
||||||
|
pp_ret = schedule.forward_backward_step(
|
||||||
|
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
||||||
|
)
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
assert torch.allclose(torch_loss, pp_ret["loss"])
|
||||||
|
|
||||||
|
for layer in sharded_model:
|
||||||
|
if layer.weight.grad is None:
|
||||||
|
assert layer.weight.grad is None and layer.bias.grad is None
|
||||||
|
else:
|
||||||
|
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
|
||||||
|
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
port: int,
|
||||||
|
num_microbatch: int,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
||||||
examine_pp()
|
examine_pp(num_microbatch, batch_size)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize("num_microbatch", [4, 12])
|
||||||
|
@pytest.mark.parametrize("batch_size", [12])
|
||||||
|
@pytest.mark.parametrize("world_size", [2, 4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_pp():
|
def test_pp(num_microbatch: int, batch_size: int, world_size: int):
|
||||||
spawn(run_dist, 2)
|
assert NUM_LAYER % world_size == 0
|
||||||
|
spawn(
|
||||||
|
run_dist,
|
||||||
|
world_size,
|
||||||
|
num_microbatch=num_microbatch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_pp()
|
test_pp(num_microbatch=4, batch_size=4, world_size=4)
|
||||||
|
|
|
@ -203,7 +203,7 @@ def check_output_hidden_state(
|
||||||
):
|
):
|
||||||
org_hidden_state = org_output.last_hidden_state
|
org_hidden_state = org_output.last_hidden_state
|
||||||
|
|
||||||
if stage_manager and stage_manager.is_last_stage():
|
if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
|
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
|
||||||
else:
|
else:
|
||||||
sharded_hidden_state = sharded_output.last_hidden_state
|
sharded_hidden_state = sharded_output.last_hidden_state
|
||||||
|
@ -229,6 +229,10 @@ def check_weight(
|
||||||
org_weight = getattr_(org_model, suffix).weight
|
org_weight = getattr_(org_model, suffix).weight
|
||||||
sharded_weight = getattr_(sharded_model, suffix).weight
|
sharded_weight = getattr_(sharded_model, suffix).weight
|
||||||
|
|
||||||
|
# skip if layer is not held by this process
|
||||||
|
if sharded_weight is None:
|
||||||
|
continue
|
||||||
|
|
||||||
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
|
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
|
||||||
sharded_weight_list = [
|
sharded_weight_list = [
|
||||||
torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
|
torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
|
||||||
|
|
|
@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
|
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
|
||||||
col_layer_for_check = ["encoder.layer[0].output.dense"]
|
col_layer_for_check = ["encoder.layer[0].output.dense"]
|
||||||
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
|
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
|
||||||
|
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
|
||||||
|
|
||||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||||
grads_to_check = {}
|
grads_to_check = {}
|
||||||
|
@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
atol, rtol = 1e-4, 1e-3
|
atol, rtol = 1e-4, 1e-3
|
||||||
else:
|
else:
|
||||||
atol, rtol = 5e-3, 5e-3
|
atol, rtol = 5e-3, 5e-3
|
||||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
|
||||||
col_layer_grads = get_grad_tensors_for_check(
|
col_layer_grads = get_grad_tensors_for_check(
|
||||||
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||||
)
|
)
|
||||||
|
@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
sharded_optimizer.step()
|
sharded_optimizer.step()
|
||||||
|
|
||||||
# check last hidden state & loss
|
# check last hidden state & loss
|
||||||
if stage_manager is None or stage_manager.is_last_stage():
|
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
if test_config["precision"] == "fp32":
|
if test_config["precision"] == "fp32":
|
||||||
atol, rtol = 1e-5, 1e-3
|
atol, rtol = 1e-5, 1e-3
|
||||||
else:
|
else:
|
||||||
|
@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
atol, rtol = 5e-3, 1e-3
|
atol, rtol = 5e-3, 1e-3
|
||||||
else:
|
else:
|
||||||
atol, rtol = 5e-3, 5e-3
|
atol, rtol = 5e-3, 5e-3
|
||||||
if stage_manager is None or stage_manager.is_first_stage():
|
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
||||||
|
|
||||||
# check grads
|
# check grads
|
||||||
check_all_grad_tensors(grads_to_check)
|
check_all_grad_tensors(grads_to_check)
|
||||||
|
@ -183,6 +184,17 @@ def run_bert_test(test_config):
|
||||||
"zero_stage": 1,
|
"zero_stage": 1,
|
||||||
"initial_scale": 1,
|
"initial_scale": 1,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 2,
|
||||||
|
"pp_style": "interleaved",
|
||||||
|
"num_model_chunks": 2,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"precision": "fp16",
|
||||||
|
"zero_stage": 1,
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_bert_3d_test(test_config):
|
def run_bert_3d_test(test_config):
|
||||||
|
|
|
@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
|
|
||||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||||
grads_to_check = {}
|
grads_to_check = {}
|
||||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
|
||||||
if test_config["precision"] == "fp32":
|
if test_config["precision"] == "fp32":
|
||||||
atol, rtol = 1e-6, 1e-4
|
atol, rtol = 1e-6, 1e-4
|
||||||
else:
|
else:
|
||||||
|
@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
sharded_optimizer.step()
|
sharded_optimizer.step()
|
||||||
|
|
||||||
# check last hidden state & loss
|
# check last hidden state & loss
|
||||||
if stage_manager is None or stage_manager.is_last_stage():
|
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
if test_config["precision"] == "fp32":
|
if test_config["precision"] == "fp32":
|
||||||
atol, rtol = 1e-5, 1e-3
|
atol, rtol = 1e-5, 1e-3
|
||||||
else:
|
else:
|
||||||
|
@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
# check weights
|
# check weights
|
||||||
if stage_manager is None or stage_manager.is_first_stage():
|
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
if test_config["precision"] == "fp32":
|
if test_config["precision"] == "fp32":
|
||||||
atol, rtol = 1e-4, 1e-3
|
atol, rtol = 1e-4, 1e-3
|
||||||
else:
|
else:
|
||||||
|
@ -179,6 +179,17 @@ def run_llama_test(test_config):
|
||||||
"zero_stage": 1,
|
"zero_stage": 1,
|
||||||
"initial_scale": 1,
|
"initial_scale": 1,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 2,
|
||||||
|
"pp_style": "interleaved",
|
||||||
|
"num_model_chunks": 2,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"precision": "fp16",
|
||||||
|
"zero_stage": 1,
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_llama_3d_test(test_config):
|
def run_llama_3d_test(test_config):
|
||||||
|
|
Loading…
Reference in New Issue