mirror of https://github.com/hpcaitech/ColossalAI
[Feature] optimize PP overlap (#5735)
* update to fully overlap, still debugging
* improve interface
* fixed deadlock bug
* debug NaN loss
* (experimental) use one comm group for send_fw_recv_fw to fix NaN
* cleaned up interfaces; use one batch p2p for all
* clean up; removed the double p2p batch case
* p2p test passed
* improve overlap: send fwd before backward
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tentatively use 2 p2p batches
* remove two p2p batches
* fix typos
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>

pull/5785/head^2
parent 4ccaaaab63
commit 2a25a2aff7
|
@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): It's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
|
||||
|
||||
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
enable_metadata_cache: bool = True,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
|
@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
assert (
|
||||
self.zero_stage <= 1
|
||||
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=self.pp_axis,
|
||||
|
@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
overlap_p2p=overlap_p2p,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
|
|
|
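For reference, a minimal sketch of how the new flag might be enabled from user code. Only `overlap_p2p` is introduced by this commit; the remaining arguments (including the `pp_style="interleaved"` value and `num_model_chunks`) are assumptions for illustration, since this hunk only shows the `"1f1b"` branch.

```python
# Hypothetical usage sketch; argument values are illustrative assumptions.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="interleaved",   # assumed value; "1f1b" appears in this hunk
    num_model_chunks=2,       # assumed kwarg of the interleaved schedule
    num_microbatches=8,
    overlap_p2p=True,         # new flag: overlap p2p communication with compute
)
booster = Booster(plugin=plugin)
```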
@ -134,7 +134,7 @@ class ProcessGroupMesh:
|
|||
"""
|
||||
|
||||
assert mode in ["raise", "wrap", "clip"]
|
||||
return np.ravel_multi_index(coord, shape, mode)
|
||||
return int(np.ravel_multi_index(coord, shape, mode))
|
||||
|
||||
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
|
||||
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
|
||||
|
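As an aside, the ravel helper above is what lets the stage manager wrap the pipeline coordinate around (used later in this diff for the "next rank of the last rank is rank 0" case). A small standalone check, with an assumed 2x4 mesh:

```python
import numpy as np

# Assumed 2x4 (dp, pp) mesh for illustration: stepping past the last pipeline
# stage wraps back to stage 0 because mode="wrap" is used.
shape = (2, 4)
coord = (1, 3)                                   # last pipeline stage, second dp replica
next_coord = (coord[0], coord[1] + 1)            # one step along the pipeline axis
rank = int(np.ravel_multi_index(next_coord, shape, mode="wrap"))
print(rank)                                      # 4 -> the wrapped coordinate (1, 0)
```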
@ -182,7 +182,7 @@ class ProcessGroupMesh:
|
|||
axis = [
|
||||
axis,
|
||||
]
|
||||
assert isinstance(indices_at_axis[0], int)
|
||||
assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
|
||||
indices_at_axis = [
|
||||
indices_at_axis,
|
||||
]
|
||||
|
|
|
@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
|
|||
send_group: Optional[ProcessGroup],
|
||||
recv_group: Optional[ProcessGroup],
|
||||
current_device: Any,
|
||||
overlap_p2p: bool = True,
|
||||
send_first: bool = True,
|
||||
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
buffer_recv = None
|
||||
if recv_tensor_metadata is not None:
|
||||
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
|
||||
|
||||
ops = []
|
||||
if send_dst is not None and send_tensor_list is not None:
|
||||
assert send_group is not None
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
if recv_src is not None and buffer_recv is not None:
|
||||
assert recv_group is not None
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
is_send = send_dst is not None and send_tensor_list is not None
|
||||
is_recv = recv_src is not None and buffer_recv is not None
|
||||
|
||||
if send_first:
|
||||
if is_send:
|
||||
assert send_group is not None
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
if is_recv:
|
||||
assert recv_group is not None
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
else:
|
||||
if is_recv:
|
||||
assert recv_group is not None
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
if is_send:
|
||||
assert send_group is not None
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# Remove synchronization according to Pytorch's documentation
|
||||
# However, Megatron-LM does synchronization here
|
||||
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
|
||||
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
return buffer_recv
|
||||
if not overlap_p2p:
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
return buffer_recv, []
|
||||
else:
|
||||
return buffer_recv, reqs
|
||||
return None, []
|
||||
|
||||
|
||||
def _send_recv_serialization_object(
|
||||
|
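The hunk above is the heart of the change: all sends and receives are queued as P2P ops in a rank-consistent order, launched with a single `batch_isend_irecv`, and the requests are either waited on immediately or handed back to the caller. A condensed, self-contained sketch of that pattern (not the exact helper above; names are illustrative):

```python
import torch
import torch.distributed as dist

def batch_send_recv(send_tensors, recv_buffers, send_dst, recv_src,
                    group=None, send_first=True, overlap_p2p=True):
    """Illustrative sketch: one batched launch with an explicit send/recv order."""
    send_ops = [dist.P2POp(dist.isend, t, send_dst, group) for t in (send_tensors or [])]
    recv_ops = [dist.P2POp(dist.irecv, t, recv_src, group) for t in (recv_buffers or [])]
    # Adjacent ranks must enqueue their ops in complementary order,
    # which is what the send_first switch controls.
    ops = send_ops + recv_ops if send_first else recv_ops + send_ops
    if not ops:
        return recv_buffers, []
    reqs = dist.batch_isend_irecv(ops)
    if overlap_p2p:
        return recv_buffers, reqs      # caller waits later, overlapping with compute
    for req in reqs:
        req.wait()
    return recv_buffers, []
```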
@ -260,10 +270,11 @@ def _send_recv_serialization_object(
|
|||
recv_group: Optional[ProcessGroup],
|
||||
current_device: Any,
|
||||
is_nccl_backend: bool,
|
||||
send_first: bool = True,
|
||||
) -> Optional[P2PMetadata]:
|
||||
ops = []
|
||||
|
||||
send_object_tensor = None
|
||||
send_object_size_tensor = None
|
||||
if object is not None and send_dst is not None:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
|
||||
|
@ -274,43 +285,54 @@ def _send_recv_serialization_object(
|
|||
send_object_size_tensor = send_object_size_tensor.to(current_device)
|
||||
send_object_tensor = send_object_tensor.to(current_device)
|
||||
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
recv_object_size_tensor = None
|
||||
if recv_src is not None:
|
||||
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
|
||||
if is_nccl_backend:
|
||||
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if send_first:
|
||||
if send_object_size_tensor is not None:
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
if recv_src is not None:
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
else:
|
||||
if recv_src is not None:
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
if send_object_size_tensor is not None:
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
# torch.cuda.synchronize()
|
||||
req.wait() # This blocks the compute stream in torch
|
||||
|
||||
ops = []
|
||||
|
||||
if send_dst is not None and send_object_tensor is not None:
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
is_send = send_dst is not None and send_object_tensor is not None
|
||||
is_recv = recv_src is not None and recv_object_size_tensor is not None
|
||||
|
||||
recv_object_tensor = None
|
||||
if recv_src is not None and recv_object_size_tensor is not None:
|
||||
if is_recv:
|
||||
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
|
||||
if is_nccl_backend:
|
||||
recv_object_tensor = recv_object_tensor.to(current_device)
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if send_first:
|
||||
if is_send:
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
if is_recv:
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
else:
|
||||
if is_recv:
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
if is_send:
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
if recv_object_tensor is not None and recv_object_size_tensor is not None:
|
||||
recv_object_tensor = recv_object_tensor.type(torch.uint8)
|
||||
if recv_object_tensor.device != torch.device("cpu"):
|
||||
|
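The serialization helper above follows a two-phase protocol: exchange a one-element size tensor first, then a byte tensor holding the serialized object. A minimal blocking sketch of the same protocol using plain pickle (the real helper uses torch's private `c10d._object_to_tensor`/`_tensor_to_object`, whose signatures vary across versions, hence the version check in this hunk):

```python
import pickle
import torch
import torch.distributed as dist

def send_pickled(obj, dst, group, device):
    """Sketch: send an arbitrary object as (size, payload) tensors."""
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8).to(device)
    size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
    dist.send(size, dst, group=group)
    dist.send(payload, dst, group=group)

def recv_pickled(src, group, device):
    """Sketch: receive the size first so the payload buffer can be allocated."""
    size = torch.empty(1, dtype=torch.long, device=device)
    dist.recv(size, src, group=group)
    payload = torch.empty(size.item(), dtype=torch.uint8, device=device)
    dist.recv(payload, src, group=group)
    return pickle.loads(payload.cpu().numpy().tobytes())
```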
@ -328,11 +350,12 @@ def _communicate(
|
|||
object: Any,
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
overlap_p2p: bool,
|
||||
send_group: Optional[ProcessGroup] = None,
|
||||
recv_group: Optional[ProcessGroup] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
send_first: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send an object to send_dst and receive an object from recv_src, respectively
|
||||
|
@ -341,6 +364,7 @@ def _communicate(
|
|||
object (Any): object needed to be sent
|
||||
send_dst (int): rank of the destination
|
||||
recv_src (int): rank of the source
|
||||
overlap_p2p (bool): whether to overlap p2p communication with computation
|
||||
send_group (ProcessGroup, optional): process group of sender
|
||||
recv_group (ProcessGroup, optional): process group of receiver
|
||||
send_metadata (bool, optional): whether to send metadata
|
||||
|
@ -358,32 +382,10 @@ def _communicate(
|
|||
# NOTE: if object contains non-tensor objects, we have to send metadata
|
||||
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
|
||||
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
|
||||
else:
|
||||
send_metadata = False
|
||||
|
||||
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
|
||||
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
|
||||
if send_prior_fallback:
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(
|
||||
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
|
||||
)
|
||||
else:
|
||||
recv_data = _communicate(
|
||||
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
|
||||
)
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return recv_data
|
||||
|
||||
# NOTE: only the following 5 cases are valid:
|
||||
# 1. send() [needs extra metadata] and no recv()
|
||||
# 2. recv() [needs extra metadata] and no send()
|
||||
# 3. neither send() nor recv() need extra metadata
|
||||
assert not (send_dst is not None and send_metadata) or recv_src is None
|
||||
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
|
||||
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
|
||||
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
|
||||
|
||||
current_send_device, is_send_nccl_backend = _check_device(send_group)
|
||||
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
|
||||
|
||||
|
@ -402,14 +404,25 @@ def _communicate(
|
|||
recv_group=recv_group if metadata_recv is None else None,
|
||||
current_device=current_device,
|
||||
is_nccl_backend=is_nccl_backend,
|
||||
send_first=send_first if send_first is not None else True,
|
||||
)
|
||||
assert metadata_recv is None or _metadata_recv is None
|
||||
assert (
|
||||
metadata_recv is None or _metadata_recv is None
|
||||
), "You shouldn't receive metadata when using the cached metadata"
|
||||
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
|
||||
|
||||
# Send and receive data
|
||||
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
|
||||
recv_tensor_objs = _batch_send_recv_tensor(
|
||||
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
|
||||
recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
|
||||
tensor_objs,
|
||||
recv_tensor_metadata,
|
||||
send_dst,
|
||||
recv_src,
|
||||
send_group,
|
||||
recv_group,
|
||||
current_device,
|
||||
overlap_p2p=overlap_p2p,
|
||||
send_first=send_first if send_first is not None else True,
|
||||
)
|
||||
|
||||
if metadata_recv is not None:
|
||||
|
@ -424,33 +437,9 @@ def _communicate(
|
|||
for idx in non_tensor_obj_idx:
|
||||
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
|
||||
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
|
||||
return recv_object, wait_handles
|
||||
|
||||
return recv_object
|
||||
|
||||
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
|
||||
"""send anything to dst rank
|
||||
|
||||
Args:
|
||||
object (Any): object needed to be sent
|
||||
dst (int): rank of the destination
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
|
||||
|
||||
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
|
||||
"""recv anything from src
|
||||
|
||||
Args:
|
||||
src (int): source rank of data. local rank will receive data from src rank.
|
||||
|
||||
Returns:
|
||||
Any: Object received from src.
|
||||
"""
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
|
||||
return None, wait_handles
|
||||
|
||||
|
||||
def _p2p_comm(
|
||||
|
@ -532,10 +521,13 @@ def _p2p_comm(
|
|||
|
||||
|
||||
class PipelineP2PCommunication:
|
||||
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
||||
def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
|
||||
self.stage_manager = stage_manager
|
||||
self.overlap_p2p = overlap_p2p
|
||||
|
||||
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||
def recv_forward(
|
||||
self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
|
||||
) -> Tuple[Any, List]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
|
||||
Args:
|
||||
|
@ -543,95 +535,186 @@ class PipelineP2PCommunication:
|
|||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
input_tensor = _recv_object(
|
||||
prev_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
|
||||
input_tensor, wait_handles = _communicate(
|
||||
object=None,
|
||||
recv_src=prev_rank,
|
||||
send_dst=None,
|
||||
recv_group=self.stage_manager.get_p2p_process_group(),
|
||||
metadata_recv=metadata_recv,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||
def recv_backward(
|
||||
self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
|
||||
) -> Tuple[Any, List]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
||||
Args:
|
||||
next_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
output_tensor_grad = _recv_object(
|
||||
next_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
|
||||
|
||||
output_tensor_grad, wait_handles = _communicate(
|
||||
object=None,
|
||||
recv_src=next_rank,
|
||||
send_dst=None,
|
||||
recv_group=self.stage_manager.get_p2p_process_group(),
|
||||
metadata_recv=metadata_recv,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(
|
||||
_, handles = _communicate(
|
||||
output_object,
|
||||
cur_rank,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
recv_src=None,
|
||||
send_dst=next_rank,
|
||||
send_group=self.stage_manager.get_p2p_process_group(),
|
||||
send_metadata=send_metadata,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
return handles
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(
|
||||
_, handles = _communicate(
|
||||
input_object,
|
||||
cur_rank,
|
||||
prev_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
|
||||
recv_src=None,
|
||||
send_dst=prev_rank,
|
||||
send_group=self.stage_manager.get_p2p_process_group(),
|
||||
send_metadata=send_metadata,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
return handles
|
||||
|
||||
def send_forward_recv_forward(
|
||||
self,
|
||||
output_object: Any,
|
||||
is_send: bool,
|
||||
is_recv: bool,
|
||||
send_first: bool,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
is_send (bool): Whether to send the output tensor to the next pipeline stage.
|
||||
is_recv (bool): Whether to copy the input tensor from the previous pipeline stage.
|
||||
send_first (bool): Whether to send before receive.
|
||||
send_metadata (bool, optional): Whether to send metadata.
|
||||
metadata_recv (P2PMetadata, optional): The cached metadata (size, type) of the object to be received.
|
||||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
next_rank = self.stage_manager.get_next_rank() if is_send else None
|
||||
prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
return _communicate(
|
||||
output_object,
|
||||
send_dst=next_rank,
|
||||
recv_src=prev_rank,
|
||||
send_group=group if is_send else None,
|
||||
recv_group=group if is_recv else None,
|
||||
send_metadata=send_metadata if is_send else False,
|
||||
metadata_recv=metadata_recv if is_recv else None,
|
||||
send_first=send_first,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self,
|
||||
input_object: Any,
|
||||
is_send: bool,
|
||||
is_recv: bool,
|
||||
send_first: bool,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
|
||||
is_recv (bool): Whether to copy the gradient tensor from the next pipeline stage.
|
||||
send_first (bool): Whether to send before receive.
|
||||
send_metadata (bool, optional): Whether to send metadata.
|
||||
metadata_recv (P2PMetadata, optional): The cached metadata (size, type) of the object to be received.
|
||||
|
||||
Returns:
|
||||
Any: The gradient tensor or gradient tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
prev_rank = self.stage_manager.get_prev_rank() if is_send else None
|
||||
next_rank = self.stage_manager.get_next_rank() if is_recv else None
|
||||
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
|
||||
return _communicate(
|
||||
input_object,
|
||||
send_dst=prev_rank,
|
||||
recv_src=next_rank,
|
||||
send_group=group if is_send else None,
|
||||
recv_group=group if is_recv else None,
|
||||
send_metadata=send_metadata if is_send else False,
|
||||
metadata_recv=metadata_recv if is_recv else None,
|
||||
send_first=send_first,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self,
|
||||
input_object: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
||||
send_first: Optional[bool] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the sender and recipient of the tensor
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
||||
Returns:
|
||||
Any: The gradient tensor or gradient tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
return _communicate(
|
||||
input_object,
|
||||
next_rank,
|
||||
|
@ -640,28 +723,28 @@ class PipelineP2PCommunication:
|
|||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
send_first=send_first,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self,
|
||||
input_object: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
send_first: Optional[bool] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the sender and recipient of the tensor
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
return _communicate(
|
||||
input_object,
|
||||
prev_rank,
|
||||
|
@ -670,7 +753,8 @@ class PipelineP2PCommunication:
|
|||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
send_first=send_first,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
def p2p_communicate(
|
||||
|
@ -679,7 +763,7 @@ class PipelineP2PCommunication:
|
|||
recv_pre: bool,
|
||||
next_rank: Optional[int] = None,
|
||||
comm_dtype: torch.dtype = torch.float16,
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""
|
||||
Sends the input tensor to the next stage in pipeline, using `P2POp` in torch.
|
||||
|
||||
|
@ -689,12 +773,11 @@ class PipelineP2PCommunication:
|
|||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
recv_tensor = _p2p_comm(
|
||||
output_object,
|
||||
recv_pre,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
self.stage_manager.get_p2p_process_group(),
|
||||
comm_dtype,
|
||||
)
|
||||
return recv_tensor
|
||||
|
|
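The new `(data, wait_handles)` return values are what enable the overlap: a caller can fire off the send, keep computing, and only wait when the buffers must be reused. An illustrative sketch (setup of the stage manager and communicator is omitted; the method signatures mirror this diff):

```python
from colossalai.pipeline.p2p import PipelineP2PCommunication

def overlapped_send(comm: PipelineP2PCommunication, output_obj, compute_backward):
    """Illustrative only: overlap the forward send with local backward compute."""
    send_handles = comm.send_forward(output_obj)   # returns immediately when overlap_p2p=True
    grad = compute_backward()                      # runs while the isend is in flight
    for req in send_handles:                       # flush before touching the buffers again
        req.wait()
    return grad
```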
|
@ -1,8 +1,9 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
|
|||
from .base import PipelineSchedule
|
||||
|
||||
|
||||
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
||||
if wait_handles is not None:
|
||||
for req in wait_handles:
|
||||
req.wait()
|
||||
|
||||
|
||||
class InterleavedSchedule(PipelineSchedule):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
num_microbatch: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
) -> None:
|
||||
super().__init__(stage_manager)
|
||||
assert (
|
||||
num_microbatch is not None or microbatch_size is not None
|
||||
), "Either num_microbatch or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
self.overlap_p2p = overlap_p2p
|
||||
self.num_microbatch = num_microbatch
|
||||
self.microbatch_size = microbatch_size
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
int: The model chunk idx of the input microbatch_id
|
||||
"""
|
||||
assert microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
assert (
|
||||
microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
|
||||
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not is_forward:
|
||||
# Reverse order
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
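The chunk-id mapping above can be checked with a tiny standalone reimplementation (4 stages and 2 model chunks are assumed values for illustration):

```python
# Standalone check of get_model_chunk_id's mapping, assuming 4 stages and 2 chunks.
num_stages, num_model_chunks = 4, 2

def chunk_id(microbatch_id: int, is_forward: bool) -> int:
    in_group = microbatch_id % (num_stages * num_model_chunks)
    cid = in_group // num_stages
    return cid if is_forward else num_model_chunks - cid - 1

print([chunk_id(i, True) for i in range(8)])   # [0, 0, 0, 0, 1, 1, 1, 1]
print([chunk_id(i, False) for i in range(8)])  # [1, 1, 1, 1, 0, 0, 0, 0]
```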
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
|
||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
Any: The input tensor or input tensor list.
List: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor, wait_handles
|
||||
return None, []
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
Any: The input gradient tensor or gradient tensor list.
List: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank, metadata_recv=self.grad_metadata_recv
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
return output_tensor_grad
|
||||
return None, []
|
||||
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
model_chunk_id (int): The current model chunk idx.
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
model_chunk_id (int): The current model chunk idx.
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
output_tensor: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_last_stage()
|
||||
|
||||
if send_data and recv_data:
|
||||
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad
|
||||
|
||||
# send only or recv only
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
return self.recv_backward(model_chunk_id_recv)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
input_tensor_grad: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_first_stage()
|
||||
|
||||
if send_data and recv_data:
|
||||
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
# send only or recv only
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
return self.recv_forward(model_chunk_id_recv)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
def send_forward_recv_forward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
else:
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
|
||||
) -> Tuple[Any, List]:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
is_send = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_first_stage()
|
||||
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
|
||||
output_tensor,
|
||||
is_send,
|
||||
is_recv,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_first=send_first,
|
||||
)
|
||||
# Cache metadata
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
|
||||
) -> Tuple[Any, List]:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
is_send = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_last_stage()
|
||||
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
is_send,
|
||||
is_recv,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_first=send_first,
|
||||
)
|
||||
# Cache metadata
|
||||
self.send_grad_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
|
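The fused send/recv methods above are driven with the metadata cache: only the first exchange carries shape/dtype metadata, after which the cached `P2PMetadata` is reused and only tensor payloads move. A hedged driver sketch (the import path, `create_send_metadata`, and the method signature mirror this diff; the loop itself is illustrative):

```python
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata

def fused_forward_exchange(comm: PipelineP2PCommunication, stage: int, outputs):
    """Illustrative driver: metadata travels once, then the cached copy is reused."""
    send_metadata, cached_metadata = True, None
    for output_obj in outputs:
        input_obj, handles = comm.send_forward_recv_forward(
            output_obj,
            is_send=True,
            is_recv=True,
            send_first=(stage % 2 == 0),   # adjacent stages pick opposite orders
            send_metadata=send_metadata,
            metadata_recv=cached_metadata,
        )
        send_metadata = False              # peers now know the shapes/dtypes
        if cached_metadata is None:
            cached_metadata = create_send_metadata(input_obj)
        for req in handles:                # wait before consuming input_obj
            req.wait()
        yield input_obj
```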
@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
"""
|
||||
# Load input ids, attention mask and labels
|
||||
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
|
||||
# for the first stage, input_obj is None
|
||||
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
|
||||
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
|
||||
# Only attention_mask from micro_batch is used
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
|
@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
|
||||
fwd_wait_handles = []
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
|
||||
|
||||
for i in range(self.num_microbatch * self.num_model_chunks):
|
||||
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
|
||||
last_batch = i == self.num_microbatch * self.num_model_chunks - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
|
||||
# Wait until current input is received
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = self.send_forward_recv_forward(
|
||||
if not last_batch:
|
||||
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=model_chunk_id,
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior=self.stage_manager.stage % 2 == 0,
|
||||
send_first=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
else:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
|
@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.load_batch(data_iter)
|
||||
|
||||
num_microbatch = self.num_microbatch * self.num_model_chunks
|
||||
# Forward passes until the 1st backward
|
||||
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
|
||||
# Steps needed to reach the last chunk
|
||||
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
|
||||
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
|
||||
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
|
||||
|
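A worked example of the warmup arithmetic above, under assumed values of 4 stages, 2 model chunks, and 8 microbatches per chunk:

```python
# Worked example of the warmup count, assuming 4 stages, 2 chunks, 8 microbatches/chunk.
num_stages, num_model_chunks, num_microbatch = 4, 2, 8 * 2

for stage in range(num_stages):
    warmup = (num_stages - stage - 1) * 2 + (num_model_chunks - 1) * num_stages
    warmup = min(warmup, num_microbatch)
    print(stage, warmup)   # stage 0 -> 10, 1 -> 8, 2 -> 6, 3 -> 4
```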
@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
|
||||
bwd_wait_handles = []
|
||||
# Get the 1st input batch
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatch):
|
||||
last_iteration = i == num_warmup_microbatch - 1
|
||||
last_batch = i == num_warmup_microbatch - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
|
||||
# Wait for input
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
if last_iteration and num_microbatch_remaining == 0:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
if last_batch and num_microbatch_remaining == 0:
|
||||
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
|
||||
else:
|
||||
input_obj = self.send_forward_recv_forward(
|
||||
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=model_chunk_id,
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior=self.stage_manager.stage % 2 == 0,
|
||||
send_first=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
|
||||
if num_microbatch_remaining > 0:
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatch_remaining):
|
||||
last_iteration = i == num_microbatch_remaining - 1
|
||||
fwd_batch_id = i + num_warmup_microbatch
|
||||
last_batch = i == num_microbatch_remaining - 1
|
||||
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
# Wait for input.
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
|
@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
# Pop output_obj and output_obj from the start of the list for the backward pass.
|
||||
_input_obj = input_objs[model_chunk_id].pop(0)
|
||||
_output_obj = output_objs[model_chunk_id].pop(0)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
|
||||
# NOTE: perform 2x communication for forward and backward
|
||||
def send_forward_recv_backward():
|
||||
if last_iteration and num_microbatch == num_microbatch_remaining:
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
# Helper functions
|
||||
def send_forward_recv_forward():
|
||||
if last_batch:
|
||||
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
|
||||
wait_handles = self.send_forward(model_chunk_id, output_obj)
|
||||
return None, wait_handles
|
||||
else:
|
||||
output_obj_grad = self.send_forward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
input_obj, wait_handles = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior_fallback=self.stage_manager.stage % 2 == 0,
|
||||
send_first=self.stage_manager.stage % 2 == 0
|
||||
and i > 0, # Receive from warmup stage first in the first batch
|
||||
)
|
||||
return output_obj_grad
|
||||
return input_obj, wait_handles
|
||||
|
||||
def send_backward_recv_forward():
|
||||
if last_iteration:
|
||||
def send_backward_recv_backward():
|
||||
no_cooldown = num_microbatch == num_microbatch_remaining
|
||||
if last_batch and no_cooldown:
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
|
||||
return None, wait_handles
|
||||
else:
|
||||
input_obj = self.send_backward_recv_forward(
|
||||
output_obj_grad, wait_handles = self.send_backward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
input_tensor_grad=input_obj_grad,
|
||||
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
|
||||
send_first=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
return input_obj
|
||||
return output_obj_grad, wait_handles
|
||||
|
||||
if self.stage_manager.stage % 2 == 0:
|
||||
output_obj_grad = send_forward_recv_backward()
|
||||
input_obj = send_backward_recv_forward()
|
||||
else:
|
||||
input_obj = send_backward_recv_forward()
|
||||
output_obj_grad = send_forward_recv_backward()
|
||||
input_obj, fwd_wait_handles = send_forward_recv_forward()
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
|
||||
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
|
||||
# however in practice this works fine, and Megatron does this too
|
||||
# (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
|
||||
# if deadlock, call _wait_p2p(fwd_wait_handles) here
|
||||
output_obj_grad, bwd_wait_handles = send_backward_recv_backward()
|
||||
|
||||
if num_microbatch_remaining == 0:
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
for i in range(num_microbatch_remaining, num_microbatch):
|
||||
last_iteration = i == num_microbatch - 1
|
||||
last_batch = i == num_microbatch - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
_input_obj = input_objs[model_chunk_id].pop(0)
|
||||
_output_obj = output_objs[model_chunk_id].pop(0)
|
||||
# output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
|
||||
if not last_iteration:
|
||||
output_obj_grad = self.send_backward_recv_backward(
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
# backward local grads
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
if not last_batch:
|
||||
output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
input_tensor_grad=input_obj_grad,
|
||||
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
|
||||
send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
|
||||
)
|
||||
assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
|
||||
else:
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
_ = self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
|
|
|
@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
num_microbatches is not None or microbatch_size is not None
|
||||
), "Either num_microbatches or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
|
||||
|
||||
self.num_microbatches = num_microbatches
|
||||
self.microbatch_size = microbatch_size
|
||||
self.batch: Optional[Any] = None
|
||||
|
@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Any: The input tensor or input tensor list.
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
|
@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Any: The input gradient tensor or gradient tensor list.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
|
@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
|
||||
) -> Any:
|
||||
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
|
@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
send_first = None
|
||||
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
send_first=send_first,
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
|
@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
|
||||
) -> Any:
|
||||
def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any:
|
||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
|
@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
send_first = None
|
||||
input_tensor, _ = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
send_first=send_first,
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
|
@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
output_obj_grad = self.send_forward_recv_backward(
|
||||
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
|
||||
)
|
||||
output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_backward(input_obj_grad)
|
||||
else:
|
||||
input_obj = self.send_backward_recv_forward(
|
||||
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
|
||||
input_obj_grad, send_first=self.stage_manager.stage % 2 == 0
|
||||
)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
|
|
|
@ -35,7 +35,7 @@ class PipelineStageManager:
        self.pipeline_axis = pipeline_axis
        self.prev_rank: Optional[Tuple[int, ...]] = None
        self.next_rank: Optional[Tuple[int, ...]] = None
        self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
        self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {}
        if num_layers_per_stage is not None:
            assert len(num_layers_per_stage) == self.num_stages
            self.num_layers_per_stage = num_layers_per_stage

@ -48,30 +48,14 @@ class PipelineStageManager:
        # the next rank of the last rank is rank0
        next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
        self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")

        # init p2p process groups
        stages = list(range(self.num_stages))
        for prev, cur in zip(stages[:-1], stages[1:]):
            group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
            if self.stage in [prev, cur]:
                ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
                self.p2p_groups[tuple(ranks_in_group)] = group

        self.is_interleave = enable_interleave
        # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
        self.num_model_chunks: int = num_model_chunks
        if enable_interleave:
            # use circle p2p communication
            # add the process group of the first rank and the last rank
            group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
            if self.stage in [stages[0], stages[-1]]:
                ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
                self.p2p_groups[tuple(ranks_in_group)] = group

            # for shardformer, hold stage indices of model
            self.stage_indices: List[Tuple[int, int]]
            # for shardformer, hold model chunk id
            self.model_chunk_id: Optional[int] = None
        # for shardformer, hold stage indices of model
        self.stage_indices: List[Tuple[int, int]]
        # for shardformer, hold model chunk id
        self.model_chunk_id: Optional[int] = None
        self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis)

    def get_stage_index(
        self,
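The constructor change above replaces the per-pair bookkeeping with one process group that spans the entire pipeline axis. A pure-Python sketch of the difference in group count, no distributed backend required and the numbers purely illustrative:

# Count the process groups each scheme needs for an N-stage pipeline (illustrative only).
num_stages = 4
stages = list(range(num_stages))

# Old scheme: one two-rank group per adjacent pair, plus a wrap-around pair for interleaving.
old_groups = list(zip(stages[:-1], stages[1:])) + [(stages[0], stages[-1])]

# New scheme: a single group shared by all stages along the pipeline axis.
new_groups = [tuple(stages)]

print(f"{len(old_groups)} groups -> {len(new_groups)} group")  # 4 -> 1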
@ -184,19 +168,12 @@ class PipelineStageManager:
        """
        return self.next_rank

    def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
    def get_p2p_process_group(self) -> ProcessGroup:
        """Get the p2p process group between two ranks. The order of the two ranks does not matter.

        Args:
            first_rank (int): The first rank.
            second_rank (int): The second rank.

        Returns:
            ProcessGroup: P2P process group between the two ranks.
        """
        if first_rank > second_rank:
            first_rank, second_rank = second_rank, first_rank
        return self.p2p_groups[(first_rank, second_rank)]
        return self.p2p_group

    def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
        """Get the process group of the given stages.
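Call sites change accordingly: the shared group is fetched without naming any ranks. A hedged usage sketch, assuming a two-stage mesh set up the same way as the test further down and an already-initialized torch.distributed process group:

import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager

# Assumes the distributed environment is already up (e.g. via colossalai.launch_from_torch()).
pg_mesh = ProcessGroupMesh(2)                     # 2-stage pipeline, mirroring the test below
stage_manager = PipelineStageManager(pg_mesh, 0)  # pipeline axis 0
group = stage_manager.get_p2p_process_group()     # no rank arguments anymore
dist.barrier(group=group)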
@ -1,9 +1,11 @@
import argparse
import resource
import time
import warnings
from contextlib import nullcontext

import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context

@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig

warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================

MODEL_CONFIGS = {
    "100m": LlamaConfig(
        max_position_embeddings=4096,
        num_hidden_layers=4,
        num_attention_heads=32,
        intermediate_size=2048,
        hidden_size=1024,
    ),
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
@ -58,6 +68,9 @@ def main():
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")

@ -78,11 +91,13 @@ def main():
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
    parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
    parser.add_argument(
        "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
    )
    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
    parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
    parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
    parser.add_argument("--no_cache", action="store_true")
    args = parser.parse_args()

    colossalai.launch_from_torch()
@ -98,6 +113,7 @@ def main():
                num_ckpt_layers_per_stage=[19, 19, 19, 13],
            ),
            "num_layers_per_stage": [19, 20, 20, 21],
            "pp_style": "interleaved",
        }
        if args.custom_ckpt
        else {}
@ -174,6 +190,8 @@ def main():
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            num_model_chunks=args.n_chunks,
            zero_stage=args.zero,
            sp_size=args.sp,
            enable_sequence_parallelism=args.sp > 1,

@ -182,12 +200,16 @@ def main():
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            overlap_p2p=args.overlap,
            enable_metadata_cache=not args.no_cache,
            **hybrid_kwargs,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            num_model_chunks=args.n_chunks,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
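On the benchmark side, the new `--overlap` flag simply forwards into the plugin as `overlap_p2p`. A pared-down construction for reference, with the parallel sizes chosen for illustration rather than copied from the script defaults:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative sizes; assumes colossalai.launch_from_torch() has set up the distributed environment.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="interleaved",
    num_model_chunks=2,
    num_microbatches=4,
    zero_stage=1,
    precision="bf16",
    overlap_p2p=True,  # the flag introduced by this PR
)
booster = Booster(plugin=plugin)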
@ -195,6 +217,7 @@ def main():
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
            overlap_p2p=args.overlap,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

@ -210,10 +233,11 @@ def main():
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
    torch.cuda.manual_seed(42)
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)

    # ==============================
    # Initialize Model and Optimizer

@ -251,6 +275,7 @@ def main():
    optimizer = HybridAdam(model.parameters())
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(
        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
@ -269,15 +294,19 @@ def main():
        data_iter = iter(dataloader)
        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
            performance_evaluator.on_step_start(step)
            booster.execute_pipeline(
            outputs = booster.execute_pipeline(
                data_iter,
                model,
                criterion=lambda outputs, inputs: outputs[0],
                optimizer=optimizer,
                return_loss=False,
                return_loss=True,
            )
            loss = outputs["loss"]
            if dist.get_rank() == dist.get_world_size() - 1:
                print(f"Step {step} loss: {loss}")
            optimizer.step()
            optimizer.zero_grad()

            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
            prof.step()
    else:
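With `return_loss=True`, `execute_pipeline` hands back a dict whose "loss" entry is meaningful only on the rank that runs the final pipeline stage, which is presumably why the print is guarded by the world-size check. A compact sketch of consuming that result, assuming the booster, model, optimizer and data_iter are already set up as in the script:

# Sketch; booster/model/optimizer/data_iter are assumed to exist as in the benchmark.
outputs = booster.execute_pipeline(
    data_iter,
    model,
    criterion=lambda outputs, inputs: outputs[0],
    optimizer=optimizer,
    return_loss=True,
)
loss = outputs["loss"]  # expected to be None on ranks that do not hold the final stage
if dist.get_rank() == dist.get_world_size() - 1:
    print(f"loss: {loss}")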
@ -288,6 +317,7 @@ def main():
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()

            performance_evaluator.on_step_end(**batch)
            prof.step()
@ -15,8 +15,7 @@ WORLD_SIZE = 2
def check_p2p_communication():
    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager)

    p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
    rank = dist.get_rank()

    tensor = torch.ones(1, device=get_accelerator().get_current_device())
@ -31,41 +30,40 @@ def check_p2p_communication():
        for obj in data:
            p2p.send_forward(obj)
        for i in range(len(data)):
            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
            recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
            assert recv_obj == data[-(i + 1)]
    elif rank == 1:
        for obj in data:
            recv_obj = p2p.recv_forward()
            recv_obj, _ = p2p.recv_forward()
            assert recv_obj == obj
        for i in range(len(data)):
            p2p.send_backward(data[-(i + 1)])
            recv_obj = p2p.recv_forward()
            recv_obj, _ = p2p.recv_forward()
            assert recv_obj == data[i]

    if rank == 1:
        for obj in data:
            p2p.send_backward(obj)
        for i in range(len(data)):
            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
            recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
            assert recv_obj == data[-(i + 1)]
    elif rank == 0:
        for obj in data:
            recv_obj = p2p.recv_backward()
            recv_obj, _ = p2p.recv_backward()
            assert recv_obj == obj
        for i in range(len(data)):
            recv_obj = p2p.recv_backward()
            p2p.send_forward(data[-(i + 1)])
            recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
            assert recv_obj == data[i]

    if rank == 0:
        recv_obj = p2p.send_forward_recv_backward(
        recv_obj, _ = p2p.send_forward_recv_backward(
            tensor,
            send_metadata=False,
            metadata_recv=create_send_metadata(tensor),
        )
        assert recv_obj == tensor
    elif rank == 1:
        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
        recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
        assert recv_obj == tensor
        p2p.send_backward(tensor, send_metadata=False)
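The test builds the communicator with `overlap_p2p=False`, so the second element of every returned tuple can be ignored. For contrast, a hypothetical overlapped variant is sketched below; it is not exercised by this test, and it assumes the returned handles behave like torch.distributed work objects with a `.wait()` method:

# Hypothetical overlapped variant (not part of the test); handle semantics are assumed.
p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=True)
recv_obj, wait_handles = p2p.send_forward_recv_backward(tensor, send_first=True)
for handle in wait_handles:
    handle.wait()  # ensure the exchange has completed before using recv_obj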
@ -52,7 +52,7 @@ def check_stage_manager():
    # check p2p groups
    for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
        if rank in [prev, cur]:
            group = stage_manager.get_p2p_process_group(prev, cur)
            group = stage_manager.get_p2p_process_group()
            dist.barrier(group=group)

    # check stage groups