
[Feature] optimize PP overlap (#5735)

* update to fully overlap, still debugging

* improve interface

* fixed deadlock bug

* debug NaN loss

* (experimental) use one comm group for send_fw_recv_fw to fix NaN

* cleaned up interfaces; use one batch p2p for all

* clean up; removed the double p2p batch case

* p2p test passed

* improve overlap: send fwd before backward

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tentatively use 2 p2p batches

* remove two p2p batches

* fix typos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
Edenzzzz committed via GitHub 5 months ago (pull/5785/head^2)
commit 2a25a2aff7
9 changed files:
  1. colossalai/booster/plugin/hybrid_parallel_plugin.py (8 changed lines)
  2. colossalai/cluster/process_group_mesh.py (4 changed lines)
  3. colossalai/pipeline/p2p.py (361 changed lines)
  4. colossalai/pipeline/schedule/interleaved_pp.py (300 changed lines)
  5. colossalai/pipeline/schedule/one_f_one_b.py (35 changed lines)
  6. colossalai/pipeline/stage_manager.py (39 changed lines)
  7. examples/language/llama/benchmark.py (44 changed lines)
  8. tests/test_pipeline/test_p2p_communication.py (20 changed lines)
  9. tests/test_pipeline/test_stage_manager.py (2 changed lines)

colossalai/booster/plugin/hybrid_parallel_plugin.py (8 changed lines)

@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
"""
def __init__(
@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
) -> None:
super().__init__()
assert (
@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase):
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
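
For context, a minimal usage sketch of the new flag; all other argument values below are illustrative, not taken from this commit:

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Hypothetical configuration: interleaved pipeline schedule with p2p overlap enabled.
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=4,
        pp_style="interleaved",
        num_model_chunks=2,
        num_microbatches=8,
        zero_stage=1,
        precision="bf16",
        overlap_p2p=True,  # the flag introduced by this commit; defaults to True
    )
    booster = Booster(plugin=plugin)

With overlap_p2p=False the schedules fall back to blocking p2p calls, which is what the 1F1B schedule does unconditionally (see one_f_one_b.py below).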

colossalai/cluster/process_group_mesh.py (4 changed lines)

@ -134,7 +134,7 @@ class ProcessGroupMesh:
"""
assert mode in ["raise", "wrap", "clip"]
return np.ravel_multi_index(coord, shape, mode)
return int(np.ravel_multi_index(coord, shape, mode))
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
@ -182,7 +182,7 @@ class ProcessGroupMesh:
axis = [
axis,
]
assert isinstance(indices_at_axis[0], int)
assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
indices_at_axis = [
indices_at_axis,
]
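
As a side note, the wrap-mode ravel that the stage manager relies on to compute the next pipeline rank behaves like this small sketch; the mesh shape and coordinate are made up:

    import numpy as np

    shape = (2, 4)          # e.g. a (dp, pp) mesh with 4 pipeline stages
    coord = (1, 3)          # last pipeline stage in the second dp group
    next_coord = (coord[0], coord[1] + 1)
    # mode="wrap" folds the out-of-range stage index back to 0, so the "next"
    # stage after the last one is stage 0 of the same dp group (flat rank 4).
    next_rank = int(np.ravel_multi_index(next_coord, shape, mode="wrap"))
    assert next_rank == 4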

colossalai/pipeline/p2p.py (361 changed lines)

@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
send_group: Optional[ProcessGroup],
recv_group: Optional[ProcessGroup],
current_device: Any,
overlap_p2p: bool = True,
send_first: bool = True,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
buffer_recv = None
if recv_tensor_metadata is not None:
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
ops = []
if send_dst is not None and send_tensor_list is not None:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if recv_src is not None and buffer_recv is not None:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
is_send = send_dst is not None and send_tensor_list is not None
is_recv = recv_src is not None and buffer_recv is not None
if send_first:
if is_send:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if is_recv:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
else:
if is_recv:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
if is_send:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# Remove synchronization according to Pytorch's documentation
# However, the Megatron-LM does synchronization here
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
# torch.cuda.synchronize()
return buffer_recv
if not overlap_p2p:
for req in reqs:
req.wait()
return buffer_recv, []
else:
return buffer_recv, reqs
return None, []
def _send_recv_serialization_object(
@ -260,10 +270,11 @@ def _send_recv_serialization_object(
recv_group: Optional[ProcessGroup],
current_device: Any,
is_nccl_backend: bool,
send_first: bool = True,
) -> Optional[P2PMetadata]:
ops = []
send_object_tensor = None
send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
@ -274,43 +285,54 @@ def _send_recv_serialization_object(
send_object_size_tensor = send_object_size_tensor.to(current_device)
send_object_tensor = send_object_tensor.to(current_device)
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
recv_object_size_tensor = None
if recv_src is not None:
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
if is_nccl_backend:
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if send_first:
if send_object_size_tensor is not None:
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if recv_src is not None:
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
else:
if recv_src is not None:
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if send_object_size_tensor is not None:
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
req.wait() # This blocks the compute stream in torch
ops = []
if send_dst is not None and send_object_tensor is not None:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
is_send = send_dst is not None and send_object_tensor is not None
is_recv = recv_src is not None and recv_object_size_tensor is not None
recv_object_tensor = None
if recv_src is not None and recv_object_size_tensor is not None:
if is_recv:
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
if is_nccl_backend:
recv_object_tensor = recv_object_tensor.to(current_device)
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if send_first:
if is_send:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if is_recv:
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
else:
if is_recv:
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if is_send:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
if recv_object_tensor is not None and recv_object_size_tensor is not None:
recv_object_tensor = recv_object_tensor.type(torch.uint8)
if recv_object_tensor.device != torch.device("cpu"):
@ -328,11 +350,12 @@ def _communicate(
object: Any,
send_dst: Optional[int],
recv_src: Optional[int],
overlap_p2p: bool,
send_group: Optional[ProcessGroup] = None,
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
send_first: Optional[bool] = None,
) -> Any:
"""
Send and receive object from send_dst and recv_src respectively
@ -341,6 +364,7 @@ def _communicate(
object (Any): object needed to be sent
send_dst (int): rank of the destination
recv_src (int): rank of the source
overlap_p2p (bool): whether to overlap p2p communication with computation
send_group (ProcessGroup, optional): process group of sender
recv_group (ProcessGroup, optional): process group of receiver
send_metadata (bool, optional): whether to send metadata
@ -358,32 +382,10 @@ def _communicate(
# NOTE: if object contains non-tensor objects, we have to send metadata
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
else:
send_metadata = False
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
if send_prior_fallback:
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
else:
recv_data = _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return recv_data
# NOTE: only the following 5 cases are valid:
# 1. send() [needs extra metadata] and no recv()
# 2. recv() [needs extra metadata] and no send()
# 3. neither send() nor recv() need extra metadata
assert not (send_dst is not None and send_metadata) or recv_src is None
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
current_send_device, is_send_nccl_backend = _check_device(send_group)
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
@ -402,14 +404,25 @@ def _communicate(
recv_group=recv_group if metadata_recv is None else None,
current_device=current_device,
is_nccl_backend=is_nccl_backend,
send_first=send_first if send_first != None else True,
)
assert metadata_recv is None or _metadata_recv is None
assert (
metadata_recv is None or _metadata_recv is None
), "You shouldn't receive metadata when using the cached metadata"
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
# Send and receive data
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
recv_tensor_objs = _batch_send_recv_tensor(
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
tensor_objs,
recv_tensor_metadata,
send_dst,
recv_src,
send_group,
recv_group,
current_device,
overlap_p2p=overlap_p2p,
send_first=send_first if send_first != None else True,
)
if metadata_recv is not None:
@ -424,33 +437,9 @@ def _communicate(
for idx in non_tensor_obj_idx:
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
return recv_object, wait_handles
return recv_object
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank
Args:
object (Any): object needed to be sent
dst (int): rank of the destination
Returns:
None
"""
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src
Args:
src (int): source rank of data. local rank will receive data from src rank.
Returns:
Any: Object received from src.
"""
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
return None, wait_handles
def _p2p_comm(
@ -532,10 +521,13 @@ def _p2p_comm(
class PipelineP2PCommunication:
def __init__(self, stage_manager: PipelineStageManager) -> None:
def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
self.stage_manager = stage_manager
self.overlap_p2p = overlap_p2p
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
def recv_forward(
self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
@ -543,95 +535,186 @@ class PipelineP2PCommunication:
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(
prev_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
input_tensor, wait_handles = _communicate(
object=None,
recv_src=prev_rank,
send_dst=None,
recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
overlap_p2p=self.overlap_p2p,
)
return input_tensor
return input_tensor, wait_handles
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
def recv_backward(
self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(
next_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
output_tensor_grad, wait_handles = _communicate(
object=None,
recv_src=next_rank,
send_dst=None,
recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
overlap_p2p=self.overlap_p2p,
)
return output_tensor_grad
return output_tensor_grad, wait_handles
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the input tensor to the next stage in pipeline.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(
_, handles = _communicate(
output_object,
cur_rank,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
recv_src=None,
send_dst=next_rank,
send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
overlap_p2p=self.overlap_p2p,
)
return handles
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(
_, handles = _communicate(
input_object,
cur_rank,
prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
recv_src=None,
send_dst=prev_rank,
send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
overlap_p2p=self.overlap_p2p,
)
return handles
def send_forward_recv_backward(
def send_forward_recv_forward(
self,
output_object: Any,
is_send: bool,
is_recv: bool,
send_first: bool,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
) -> Tuple[Any, List]:
"""Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage
Args:
output_object (Any): Object to be sent.
is_send (bool): Whether to send the input tensor to the next pipeline stage.
is_recv (bool): Whether to copy the output tensor from the next pipeline stage.
send_first (bool): Whether to send before receive.
send_metadata (bool, optional): Whether to send metadata.
metadata_recv (P2PMetadata, optional): The cached metadata (size, type) of the object to be received.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
next_rank = self.stage_manager.get_next_rank() if is_send else None
prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
group = self.stage_manager.get_p2p_process_group()
return _communicate(
output_object,
send_dst=next_rank,
recv_src=prev_rank,
send_group=group if is_send else None,
recv_group=group if is_recv else None,
send_metadata=send_metadata if is_send else False,
metadata_recv=metadata_recv if is_recv else None,
send_first=send_first,
overlap_p2p=self.overlap_p2p,
)
def send_backward_recv_backward(
self,
input_object: Any,
next_rank: Optional[int] = None,
is_send: bool,
is_recv: bool,
send_first: bool,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
) -> Tuple[Any, List]:
"""Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage
Args:
input_object (Any): Object to be sent.
next_rank (int, optional): The rank of the sender and recipient of the tensor
is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage.
send_first (bool): Whether to send before receive.
send_metadata (bool, optional): Whether to send metadata.
metadata_recv (P2PMetadata, optional): The cached metadata (size, type) of the object to be received.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
prev_rank = self.stage_manager.get_prev_rank() if is_send else None
next_rank = self.stage_manager.get_next_rank() if is_recv else None
group = self.stage_manager.get_p2p_process_group()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
return _communicate(
input_object,
send_dst=prev_rank,
recv_src=next_rank,
send_group=group if is_send else None,
recv_group=group if is_recv else None,
send_metadata=send_metadata if is_send else False,
metadata_recv=metadata_recv if is_recv else None,
send_first=send_first,
overlap_p2p=self.overlap_p2p,
)
def send_forward_recv_backward(
self,
input_object: Any,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage
Args:
input_object (Any): Object to be sent.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
next_rank = self.stage_manager.get_next_rank()
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
next_rank,
@ -640,28 +723,28 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
overlap_p2p=False,
)
def send_backward_recv_forward(
self,
input_object: Any,
prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
prev_rank = self.stage_manager.get_prev_rank()
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
prev_rank,
@ -670,7 +753,8 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
overlap_p2p=False,
)
def p2p_communicate(
@ -679,7 +763,7 @@ class PipelineP2PCommunication:
recv_pre: bool,
next_rank: Optional[int] = None,
comm_dtype: torch.dtype = torch.float16,
) -> None:
) -> Any:
"""
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@ -689,12 +773,11 @@ class PipelineP2PCommunication:
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(
output_object,
recv_pre,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
self.stage_manager.get_p2p_process_group(),
comm_dtype,
)
return recv_tensor
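
The send_first ordering above boils down to the following sketch: each peer puts its isend and irecv into a single batch, and the two sides alternate which op is queued first so the orders match up. Ranks, tensors and the helper name are illustrative; the real code also handles metadata exchange and buffer creation:

    import torch
    import torch.distributed as dist

    def exchange(tensor, peer, group, send_first):
        """Hypothetical helper: batch one isend and one irecv with a peer."""
        recv_buf = torch.empty_like(tensor)
        send_op = dist.P2POp(dist.isend, tensor, peer, group=group)
        recv_op = dist.P2POp(dist.irecv, recv_buf, peer, group=group)
        ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        reqs = dist.batch_isend_irecv(ops)
        # With overlap_p2p=True the requests are handed back to the caller and
        # only waited on right before the received buffer is consumed.
        return recv_buf, reqs

    # The schedules pick the order by stage parity, e.g. even stages send first
    # and odd stages receive first, so adjacent stages always pair up.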

colossalai/pipeline/schedule/interleaved_pp.py (300 changed lines)

@ -1,8 +1,9 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
from .base import PipelineSchedule
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
class InterleavedSchedule(PipelineSchedule):
def __init__(
self,
@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
self.overlap_p2p = overlap_p2p
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
int: The model chunk idx of the input microbatch_id
"""
assert microbatch_id < self.num_microbatch * self.num_model_chunks
assert (
microbatch_id < self.num_microbatch * self.num_model_chunks
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
# Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor, wait_handles
return None, []
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad, wait_handles
return output_tensor_grad
return None, []
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
return []
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
return send_handles
return []
def send_forward_recv_backward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
output_tensor: Any,
next_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
def send_forward_recv_forward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
) -> Tuple[Any, List]:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_last_stage()
is_send = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_last_stage()
if send_data and recv_data:
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
is_recv = not self.stage_manager.is_first_stage()
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
output_tensor,
is_send,
is_recv,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.tensor_metadata_recv,
send_first=send_first,
)
# Cache metadata
self.send_tensor_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor, wait_handles
# send only or recv only
self.send_forward(model_chunk_id_send, output_tensor)
return self.recv_backward(model_chunk_id_recv)
def send_backward_recv_forward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
input_tensor_grad: Any,
prev_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
def send_backward_recv_backward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
) -> Tuple[Any, List]:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_first_stage()
is_send = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_first_stage()
if send_data and recv_data:
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
# send only or recv only
self.send_backward(model_chunk_id_send, input_tensor_grad)
return self.recv_forward(model_chunk_id_recv)
def send_forward_recv_forward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
):
if send_prior:
self.send_forward(model_chunk_id_send, output_tensor)
input_tensor = self.recv_forward(model_chunk_id_recv)
else:
input_tensor = self.recv_forward(model_chunk_id_recv)
self.send_forward(model_chunk_id_send, output_tensor)
return input_tensor
def send_backward_recv_backward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
):
if send_prior:
self.send_backward(model_chunk_id_send, input_tensor_grad)
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
else:
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
self.send_backward(model_chunk_id_send, input_tensor_grad)
return output_tensor_grad
is_recv = not self.stage_manager.is_last_stage()
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
input_tensor_grad,
is_send,
is_recv,
send_metadata=self.send_grad_metadata,
metadata_recv=self.grad_metadata_recv,
send_first=send_first,
)
# Cache metadata
self.send_grad_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad, wait_handles
def forward_step(
self,
@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
fwd_wait_handles = []
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
for i in range(self.num_microbatch * self.num_model_chunks):
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
last_batch = i == self.num_microbatch * self.num_model_chunks - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# Wait until current input is received
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not last_iteration:
input_obj = self.send_forward_recv_forward(
if not last_batch:
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
else:
self.send_forward(model_chunk_id, output_obj)
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
# Forward + until 1st backward
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
# Steps needed to reach the last chunk
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
bwd_wait_handles = []
# Get the 1st input batch
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
last_iteration = i == num_warmup_microbatch - 1
last_batch = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# Wait for input
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
if last_iteration and num_microbatch_remaining == 0:
self.send_forward(model_chunk_id, output_obj)
if last_batch and num_microbatch_remaining == 0:
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
else:
input_obj = self.send_forward_recv_forward(
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
last_iteration = i == num_microbatch_remaining - 1
fwd_batch_id = i + num_warmup_microbatch
last_batch = i == num_microbatch_remaining - 1
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
# Wait for input.
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
# Pop output_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: perform 2x communication for forward and backward
def send_forward_recv_backward():
if last_iteration and num_microbatch == num_microbatch_remaining:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
self.send_forward(model_chunk_id, output_obj)
# Helper functions
def send_forward_recv_forward():
if last_batch:
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
wait_handles = self.send_forward(model_chunk_id, output_obj)
return None, wait_handles
else:
output_obj_grad = self.send_forward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_obj, wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
output_tensor=output_obj,
send_prior_fallback=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0
and i > 0, # Receive from warmup stage first in the first batch
)
return output_obj_grad
return input_obj, wait_handles
def send_backward_recv_forward():
if last_iteration:
def send_backward_recv_backward():
no_cooldown = num_microbatch == num_microbatch_remaining
if last_batch and no_cooldown:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
return None, wait_handles
else:
input_obj = self.send_backward_recv_forward(
output_obj_grad, wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
send_first=self.stage_manager.stage % 2 == 0,
)
return input_obj
return output_obj_grad, wait_handles
if self.stage_manager.stage % 2 == 0:
output_obj_grad = send_forward_recv_backward()
input_obj = send_backward_recv_forward()
else:
input_obj = send_backward_recv_forward()
output_obj_grad = send_forward_recv_backward()
input_obj, fwd_wait_handles = send_forward_recv_forward()
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
# however in practice this works fine, and Megatron does this too
# (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
# if deadlock, call _wait_p2p(fwd_wait_handles) here
output_obj_grad, bwd_wait_handles = send_backward_recv_backward()
if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run cooldown backward passes.
for i in range(num_microbatch_remaining, num_microbatch):
last_iteration = i == num_microbatch - 1
last_batch = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
# output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_iteration:
output_obj_grad = self.send_backward_recv_backward(
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
# backward local grads
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_batch:
output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
_ = self.send_backward(model_chunk_id, input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
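
Stripped of bookkeeping, the overlap pattern used throughout this schedule looks roughly like the hypothetical driver below; the schedule, chunk_id and stage arguments stand in for the state kept by run_interleaved_forward_backward itself:

    from typing import Callable

    from colossalai.pipeline.schedule.interleaved_pp import _wait_p2p

    def overlapped_forward_loop(schedule, model_chunk, num_batches: int,
                                chunk_id: Callable[[int], int], stage: int,
                                criterion, accum_loss, outputs) -> None:
        """Sketch of the overlap pattern: issue the next send/recv before running
        the current forward, and only wait right before the input is needed."""
        input_obj, fwd_wait_handles = schedule.recv_forward(model_chunk_id=chunk_id(0))
        for i in range(num_batches):
            _wait_p2p(fwd_wait_handles)  # block only when the current input is required
            output_obj = schedule.forward_step(
                model_chunk, chunk_id(i), input_obj, criterion, accum_loss, outputs
            )
            if i + 1 < num_batches:
                input_obj, fwd_wait_handles = schedule.send_forward_recv_forward(
                    model_chunk_id_send=chunk_id(i),
                    model_chunk_id_recv=chunk_id(i + 1),
                    output_tensor=output_obj,
                    send_first=stage % 2 == 0,  # alternate ordering by stage parity
                )
            else:
                fwd_wait_handles = schedule.send_forward(chunk_id(i), output_obj)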

colossalai/pipeline/schedule/one_f_one_b.py (35 changed lines)

@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input tensor or input tensor list.
"""
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input gradient tensor or gradient tensor list.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
send_first = None
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
return output_tensor_grad
def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
send_first = None # must not fallback
input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
output_obj_grad = self.send_forward_recv_backward(
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
input_obj_grad, send_first=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
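
Note that this schedule creates its communicator with overlap_p2p=False, so every p2p call blocks internally and the second element of the returned tuple is an empty handle list. A tiny sketch of what that means for callers, where stage_manager stands in for a properly initialized PipelineStageManager:

    from colossalai.pipeline.p2p import PipelineP2PCommunication

    comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
    input_tensor, handles = comm.recv_forward()
    # With overlap disabled, _communicate has already waited on all requests,
    # so there is nothing left to synchronize here.
    assert handles == []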

colossalai/pipeline/stage_manager.py (39 changed lines)

@ -35,7 +35,7 @@ class PipelineStageManager:
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage
@ -48,30 +48,14 @@ class PipelineStageManager:
# the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
# init p2p process groups
stages = list(range(self.num_stages))
for prev, cur in zip(stages[:-1], stages[1:]):
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
if self.stage in [prev, cur]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis)
def get_stage_index(
self,
@ -184,19 +168,12 @@ class PipelineStageManager:
"""
return self.next_rank
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
def get_p2p_process_group(self) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.
Args:
first_rank (int): The first rank.
second_rank (int): The second rank.
Returns:
ProcessGroup: P2P process group between the two ranks.
"""
if first_rank > second_rank:
first_rank, second_rank = second_rank, first_rank
return self.p2p_groups[(first_rank, second_rank)]
return self.p2p_group
def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
"""Get the process group of the given stages.

examples/language/llama/benchmark.py (44 changed lines)

@ -1,9 +1,11 @@
import argparse
import resource
import time
import warnings
from contextlib import nullcontext
import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================
MODEL_CONFIGS = {
"100m": LlamaConfig(
max_position_embeddings=4096,
num_hidden_layers=4,
num_attention_heads=32,
intermediate_size=2048,
hidden_size=1024,
),
"7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,
@ -58,6 +68,9 @@ def main():
default="gemini",
help="Choose which plugin to use",
)
parser.add_argument(
"--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
)
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@ -78,11 +91,13 @@ def main():
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
parser.add_argument(
"--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@ -98,6 +113,7 @@ def main():
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
}
if args.custom_ckpt
else {}
@ -174,6 +190,8 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
enable_sequence_parallelism=args.sp > 1,
@ -182,12 +200,16 @@ def main():
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
@ -195,6 +217,7 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@ -210,10 +233,11 @@ def main():
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
# ==============================
# Initialize Model and Optimizer
@ -251,6 +275,7 @@ def main():
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
@ -269,15 +294,19 @@ def main():
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
booster.execute_pipeline(
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=False,
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
else:
@ -288,6 +317,7 @@ def main():
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
prof.step()

tests/test_pipeline/test_p2p_communication.py (20 changed lines)

@ -15,8 +15,7 @@ WORLD_SIZE = 2
def check_p2p_communication():
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
stage_manager = PipelineStageManager(pg_mesh, 0)
p2p = PipelineP2PCommunication(stage_manager)
p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
rank = dist.get_rank()
tensor = torch.ones(1, device=get_accelerator().get_current_device())
@ -31,41 +30,40 @@ def check_p2p_communication():
for obj in data:
p2p.send_forward(obj)
for i in range(len(data)):
recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
assert recv_obj == data[-(i + 1)]
elif rank == 1:
for obj in data:
recv_obj = p2p.recv_forward()
recv_obj, _ = p2p.recv_forward()
assert recv_obj == obj
for i in range(len(data)):
p2p.send_backward(data[-(i + 1)])
recv_obj = p2p.recv_forward()
recv_obj, _ = p2p.recv_forward()
assert recv_obj == data[i]
if rank == 1:
for obj in data:
p2p.send_backward(obj)
for i in range(len(data)):
recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
assert recv_obj == data[-(i + 1)]
elif rank == 0:
for obj in data:
recv_obj = p2p.recv_backward()
recv_obj, _ = p2p.recv_backward()
assert recv_obj == obj
for i in range(len(data)):
recv_obj = p2p.recv_backward()
p2p.send_forward(data[-(i + 1)])
recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
assert recv_obj == data[i]
if rank == 0:
recv_obj = p2p.send_forward_recv_backward(
recv_obj, _ = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=create_send_metadata(tensor),
)
assert recv_obj == tensor
elif rank == 1:
recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)

tests/test_pipeline/test_stage_manager.py (2 changed lines)

@ -52,7 +52,7 @@ def check_stage_manager():
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
group = stage_manager.get_p2p_process_group(prev, cur)
group = stage_manager.get_p2p_process_group()
dist.barrier(group=group)
# check stage groups
