[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)

* fix: add fallback order option and update 1f1b

* fix: fix deadlock comm in interleaved pp

* test: modify p2p test
Wenhao Chen 2024-01-03 11:34:49 +08:00 committed by Xuanlei Zhao
parent 931d0e0731
commit 196b85368b
5 changed files with 269 additions and 136 deletions
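Editor's note on the fix, as a hedged sketch: the deadlock addressed here arises when two neighbouring pipeline ranks each need to both send and receive in the same step, the pair cannot be issued as one atomic operation, and both ranks try to send first. The change orders the two halves by stage parity so a blocking send on one rank always meets an already-posted receive on its neighbour. The helper below is illustrative only; its names are not part of the ColossalAI API.

from typing import Any, Callable


def paired_send_recv(send_op: Callable[[], None], recv_op: Callable[[], Any], send_prior: bool) -> Any:
    """Run a send/recv pair in the order given by `send_prior`; return whatever was received."""
    if send_prior:
        send_op()            # even stages: send first ...
        return recv_op()     # ... then receive
    received = recv_op()     # odd stages: receive first ...
    send_op()                # ... then send
    return received


# In the schedules changed below, the priority is derived from stage parity,
# e.g. send_prior = stage_manager.stage % 2 == 0, so two adjacent ranks never
# both block on their send at the same time.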


@@ -1,5 +1,4 @@
import ctypes
-import os
import random
from contextlib import contextmanager
from functools import partial
@@ -23,7 +22,6 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -984,13 +982,6 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
-            if os.getenv("NCCL_BUFFSIZE") is None:
-                logger = get_dist_logger()
-                logger.warning(
-                    "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. "
-                    "Please increase it if hangs still happen."
-                )
-                os.environ["NCCL_BUFFSIZE"] = "134217728"
            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
            assert (


@@ -344,6 +344,7 @@ def _communicate(
    recv_group: Optional[ProcessGroup] = None,
    send_metadata: bool = True,
    metadata_recv: Optional[P2PMetadata] = None,
+    send_prior_fallback: Optional[bool] = None,
) -> Any:
    """
    Send and receive object from send_dst and recv_src respectively
@@ -368,8 +369,14 @@ def _communicate(
    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
    if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
-        _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
-        return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+        assert send_prior_fallback is not None, "Priority must be set if fallback happens"
+        if send_prior_fallback:
+            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+            return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+        else:
+            recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+            return recv_data

    # NOTE: only the following 5 cases are valid:
    # 1. send() [needs extra metadata] and no recv()
@@ -437,7 +444,7 @@ def _communicate(
            raise ValueError("Unknown data type {}".format(metadata_recv.data_type))


-def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
+def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
    """send anything to dst rank

    Args:
@@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta
    Returns:
        None
    """
-    _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)


-def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
+def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
    """recv anything from src

    Args:
@@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona
    Returns:
        Any: Object received from src.
    """
-    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)


def _p2p_comm(
@@ -557,7 +564,10 @@ class PipelineP2PCommunication:
            prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        input_tensor = _recv_object(
-            prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
+            prev_rank,
+            cur_rank,
+            self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
+            metadata_recv=metadata_recv,
        )

        return input_tensor
@@ -575,7 +585,10 @@ class PipelineP2PCommunication:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        output_tensor_grad = _recv_object(
-            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
+            next_rank,
+            cur_rank,
+            self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
+            metadata_recv=metadata_recv,
        )

        return output_tensor_grad
@@ -595,7 +608,7 @@ class PipelineP2PCommunication:
            cur_rank,
            next_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
-            send_metadata,
+            send_metadata=send_metadata,
        )

    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
@@ -613,7 +626,7 @@ class PipelineP2PCommunication:
            cur_rank,
            prev_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
-            send_metadata,
+            send_metadata=send_metadata,
        )
    def send_forward_recv_backward(
@@ -622,6 +635,7 @@ class PipelineP2PCommunication:
        next_rank: Optional[int] = None,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
+        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
@@ -642,6 +656,7 @@ class PipelineP2PCommunication:
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
+            send_prior_fallback=send_prior_fallback,
        )

    def send_backward_recv_forward(
@@ -650,6 +665,7 @@ class PipelineP2PCommunication:
        prev_rank: Optional[int] = None,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
+        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
@@ -670,6 +686,7 @@ class PipelineP2PCommunication:
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
+            send_prior_fallback=send_prior_fallback,
        )

    def p2p_communicate(
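Editor's note, a hedged sketch of the decision the new `_communicate` fallback makes (the callables below are hypothetical stand-ins for the recursive `_communicate` calls, not ColossalAI functions): the send/recv pair is only split, and therefore only needs an ordering, when tensor metadata still has to be exchanged; that is why the priority is asserted to be set in that branch.

from typing import Any, Callable, Optional


def send_and_recv(
    obj: Any,
    send_metadata: bool,
    metadata_recv: Optional[Any],
    send_prior_fallback: Optional[bool],
    batched_send_recv: Callable[[Any], Any],
    ordered_send_recv: Callable[[Any, bool], Any],
) -> Any:
    # When no metadata has to travel in either direction, send and recv can be
    # issued together as one batched, atomic p2p op; no ordering is required.
    if not send_metadata and metadata_recv is not None:
        return batched_send_recv(obj)
    # Otherwise the pair is split into two one-sided calls, and the caller must
    # say which half runs first (this mirrors the assert added in the diff above).
    assert send_prior_fallback is not None, "Priority must be set if fallback happens"
    return ordered_send_recv(obj, send_prior_fallback)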


@@ -41,10 +41,10 @@ class InterleavedSchedule(PipelineSchedule):
        # P2PMeta cache
        self.enable_metadata_cache = enable_metadata_cache
-        self.send_metadata_forward = True
-        self.send_metadata_backward = True
-        self.metadata_recv_forward = None
-        self.metadata_recv_backward = None
+        self.send_tensor_metadata = True
+        self.send_grad_metadata = True
+        self.tensor_metadata_recv = None
+        self.grad_metadata_recv = None

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.
@@ -77,10 +77,10 @@ class InterleavedSchedule(PipelineSchedule):
        # NOTE: disable metadata cache when batch size changes (not valid anymore)
        if self.batch_size != self.last_batch_size:
            self.enable_metadata_cache = False
-            self.send_metadata_forward = True
-            self.send_metadata_backward = True
-            self.metadata_recv_forward = None
-            self.metadata_recv_backward = None
+            self.send_tensor_metadata = True
+            self.send_grad_metadata = True
+            self.tensor_metadata_recv = None
+            self.grad_metadata_recv = None

        self.last_batch_size = self.batch_size
@@ -108,7 +108,8 @@ class InterleavedSchedule(PipelineSchedule):
        Returns:
            int: The model chunk idx of the input microbatch_id
        """
-        microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
+        assert microbatch_id < self.num_microbatch * self.num_model_chunks
+        microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
        model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
        if not is_forward:
            model_chunk_id = self.num_model_chunks - model_chunk_id - 1
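The reworked `get_model_chunk_id` above reduces to the following standalone mapping (restated here without `self` for clarity; the added assert simply bounds `microbatch_id` by the total number of microbatch/chunk slots):

def get_model_chunk_id(microbatch_id: int, num_stages: int, num_model_chunks: int, is_forward: bool) -> int:
    # Microbatches cycle through model chunks in groups of num_stages * num_model_chunks.
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not is_forward:
        # Backward passes traverse the chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id


# e.g. with 4 stages and 2 chunks, forward microbatches 0-3 map to chunk 0 and
# 4-7 to chunk 1; in the backward direction the mapping is mirrored (0-3 -> 1, 4-7 -> 0).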
@@ -127,9 +128,9 @@ class InterleavedSchedule(PipelineSchedule):
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_first_stage():
-                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-                if self.enable_metadata_cache and self.metadata_recv_forward is None:
-                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                    self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)

                return input_tensor
@@ -146,13 +147,13 @@ class InterleavedSchedule(PipelineSchedule):
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_last_stage():
-                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-                if self.enable_metadata_cache and self.metadata_recv_backward is None:
-                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+                if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                    self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)

                return output_tensor_grad

-    def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
        """Sends the input tensor to the next stage in pipeline.
        For interleaved 1F1B.
@@ -163,10 +164,10 @@ class InterleavedSchedule(PipelineSchedule):
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_last_stage():
-                self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-                self.send_metadata_forward = not self.enable_metadata_cache
+                self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
+                self.send_tensor_metadata = not self.enable_metadata_cache

-    def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.
        For interleaved 1F1B.
@@ -177,42 +178,96 @@ class InterleavedSchedule(PipelineSchedule):
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_first_stage():
-                self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-                self.send_metadata_backward = not self.enable_metadata_cache
+                self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
+                self.send_grad_metadata = not self.enable_metadata_cache

    def send_forward_recv_backward(
-        self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
+        self,
+        model_chunk_id_send: int,
+        model_chunk_id_recv: int,
+        output_tensor: Any,
+        next_rank: Optional[int] = None,
+        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
-        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            if not self.stage_manager.is_last_stage():
-                output_tensor_grad = self.comm.send_forward_recv_backward(
-                    output_object,
-                    next_rank,
-                    send_metadata=self.send_metadata_forward,
-                    metadata_recv=self.metadata_recv_backward,
-                )
-                self.send_metadata_forward = not self.enable_metadata_cache
-                if self.enable_metadata_cache and self.metadata_recv_backward is None:
-                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
-
-                return output_tensor_grad
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+            send_data = not self.stage_manager.is_last_stage()
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+            recv_data = not self.stage_manager.is_last_stage()
+
+        if send_data and recv_data:
+            if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
+            output_tensor_grad = self.comm.send_forward_recv_backward(
+                output_tensor,
+                next_rank,
+                send_metadata=self.send_tensor_metadata,
+                metadata_recv=self.grad_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
+            )
+            self.send_tensor_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+            return output_tensor_grad
+
+        # send only or recv only
+        self.send_forward(model_chunk_id_send, output_tensor)
+        return self.recv_backward(model_chunk_id_recv)

    def send_backward_recv_forward(
-        self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
+        self,
+        model_chunk_id_send: int,
+        model_chunk_id_recv: int,
+        input_tensor_grad: Any,
+        prev_rank: Optional[int] = None,
+        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
-        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            if not self.stage_manager.is_first_stage():
-                input_tensor = self.comm.send_backward_recv_forward(
-                    output_object,
-                    prev_rank,
-                    send_metadata=self.send_metadata_backward,
-                    metadata_recv=self.metadata_recv_forward,
-                )
-                self.send_metadata_backward = not self.enable_metadata_cache
-                if self.enable_metadata_cache and self.metadata_recv_forward is None:
-                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
-
-                return input_tensor
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+            send_data = not self.stage_manager.is_first_stage()
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+            recv_data = not self.stage_manager.is_first_stage()
+
+        if send_data and recv_data:
+            if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
+            input_tensor = self.comm.send_backward_recv_forward(
+                input_tensor_grad,
+                prev_rank,
+                send_metadata=self.send_grad_metadata,
+                metadata_recv=self.tensor_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
+            )
+            self.send_grad_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+            return input_tensor
+
+        # send only or recv only
+        self.send_backward(model_chunk_id_send, input_tensor_grad)
+        return self.recv_forward(model_chunk_id_recv)
+
+    def send_forward_recv_forward(
+        self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
+    ):
+        if send_prior:
+            self.send_forward(model_chunk_id_send, output_tensor)
+            input_tensor = self.recv_forward(model_chunk_id_recv)
+        else:
+            input_tensor = self.recv_forward(model_chunk_id_recv)
+            self.send_forward(model_chunk_id_send, output_tensor)
+
+        return input_tensor
+
+    def send_backward_recv_backward(
+        self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
+    ):
+        if send_prior:
+            self.send_backward(model_chunk_id_send, input_tensor_grad)
+            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
+        else:
+            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
+            self.send_backward(model_chunk_id_send, input_tensor_grad)
+
+        return output_tensor_grad

    def forward_step(
        self,
@@ -321,12 +376,23 @@ class InterleavedSchedule(PipelineSchedule):
        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
            accum_loss = torch.scalar_tensor(0, device=get_current_device())

-        # Run warmup forward passes.
+        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
+        input_obj = self.recv_forward(model_chunk_id)
        for i in range(self.num_microbatch * self.num_model_chunks):
+            last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            self.send_forward(model_chunk_id, output_obj)
+
+            if not last_iteration:
+                input_obj = self.send_forward_recv_forward(
+                    model_chunk_id_send=model_chunk_id,
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
+                    output_tensor=output_obj,
+                    send_prior=self.stage_manager.stage % 2 == 0,
+                )
+            else:
+                self.send_forward(model_chunk_id, output_obj)

        if outputs is not None:
            outputs = merge_batch(outputs)
@@ -364,54 +430,102 @@ class InterleavedSchedule(PipelineSchedule):
        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
            accum_loss = torch.scalar_tensor(0, device=get_current_device())

+        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
+        input_obj = self.recv_forward(model_chunk_id)
+
        # Run warmup forward passes.
        for i in range(num_warmup_microbatch):
+            last_iteration = i == num_warmup_microbatch - 1
            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
            input_objs[model_chunk_id].append(input_obj)
            output_objs[model_chunk_id].append(output_obj)
-            self.send_forward(model_chunk_id, output_obj)
+
+            if last_iteration and num_microbatch_remaining == 0:
+                self.send_forward(model_chunk_id, output_obj)
+            else:
+                input_obj = self.send_forward_recv_forward(
+                    model_chunk_id_send=model_chunk_id,
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
+                    output_tensor=output_obj,
+                    send_prior=self.stage_manager.stage % 2 == 0,
+                )

        if num_microbatch_remaining > 0:
-            model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
+            model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
+            output_obj_grad = self.recv_backward(model_chunk_id)

        # Run 1F1B in steady state.
        for i in range(num_microbatch_remaining):
-            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
            last_iteration = i == num_microbatch_remaining - 1
+            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)

            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            self.send_forward(model_chunk_id, output_obj)

            # Add input_obj and output_obj to end of list.
            input_objs[model_chunk_id].append(input_obj)
            output_objs[model_chunk_id].append(output_obj)

            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-            output_obj_grad = self.recv_backward(model_chunk_id)
-
-            # Pop output_obj and output_obj from the start of the list for
-            # the backward pass.
-            input_obj = input_objs[model_chunk_id].pop(0)
-            output_obj = output_objs[model_chunk_id].pop(0)
-
-            # backward
-            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-            self.send_backward(model_chunk_id, input_obj_grad)
-
-            if not last_iteration:
-                model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
-                input_obj = self.recv_forward(model_chunk_id)
+            # Pop output_obj and output_obj from the start of the list for the backward pass.
+            _input_obj = input_objs[model_chunk_id].pop(0)
+            _output_obj = output_objs[model_chunk_id].pop(0)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+
+            # NOTE: perform 2x communication for forward and backward
+            def send_forward_recv_backward():
+                if last_iteration and num_microbatch == num_microbatch_remaining:
+                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+                    self.send_forward(model_chunk_id, output_obj)
+                else:
+                    output_obj_grad = self.send_forward_recv_backward(
+                        model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                        output_tensor=output_obj,
+                        send_prior_fallback=self.stage_manager.stage % 2 == 0,
+                    )
+                    return output_obj_grad
+
+            def send_backward_recv_forward():
+                if last_iteration:
+                    model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                    self.send_backward(model_chunk_id, input_obj_grad)
+                else:
+                    input_obj = self.send_backward_recv_forward(
+                        model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
+                        input_tensor_grad=input_obj_grad,
+                        send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
+                    )
+                    return input_obj
+
+            if self.stage_manager.stage % 2 == 0:
+                output_obj_grad = send_forward_recv_backward()
+                input_obj = send_backward_recv_forward()
+            else:
+                input_obj = send_backward_recv_forward()
+                output_obj_grad = send_forward_recv_backward()
+
+        if num_microbatch_remaining == 0:
+            model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
+            output_obj_grad = self.recv_backward(model_chunk_id)

        # Run cooldown backward passes.
        for i in range(num_microbatch_remaining, num_microbatch):
+            last_iteration = i == num_microbatch - 1
            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-            input_obj = input_objs[model_chunk_id].pop(0)
-            output_obj = output_objs[model_chunk_id].pop(0)
-            output_obj_grad = self.recv_backward(model_chunk_id)
-            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-            self.send_backward(model_chunk_id, input_obj_grad)
+            _input_obj = input_objs[model_chunk_id].pop(0)
+            _output_obj = output_objs[model_chunk_id].pop(0)
+            # output_obj_grad = self.recv_backward(model_chunk_id)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+
+            if not last_iteration:
+                output_obj_grad = self.send_backward_recv_backward(
+                    model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                    input_tensor_grad=input_obj_grad,
+                    send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
+                )
+            else:
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                self.send_backward(model_chunk_id, input_obj_grad)

        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
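For the steady-state loop in `run_forward_backward` above, each stage now performs two combined exchanges per iteration, and their order is flipped between even and odd stages so that the two exchanges between any pair of neighbours line up. A hedged sketch of just that ordering (`fwd_comm` / `bwd_comm` are placeholders for the nested send_forward_recv_backward / send_backward_recv_forward helpers, not library functions):

from typing import Any, Callable, Tuple


def steady_state_exchange(
    stage: int, fwd_comm: Callable[[], Any], bwd_comm: Callable[[], Any]
) -> Tuple[Any, Any]:
    """Run the forward-direction and backward-direction exchanges in parity order."""
    if stage % 2 == 0:
        output_grad = fwd_comm()   # even stage: forward-direction exchange first
        input_tensor = bwd_comm()
    else:
        input_tensor = bwd_comm()  # odd stage: backward-direction exchange first
        output_grad = fwd_comm()
    return input_tensor, output_grad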


@@ -54,10 +54,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        # P2PMeta cache
        self.enable_metadata_cache = enable_metadata_cache
-        self.send_metadata_forward = True
-        self.send_metadata_backward = True
-        self.metadata_recv_forward = None
-        self.metadata_recv_backward = None
+        self.send_tensor_metadata = True
+        self.send_grad_metadata = True
+        self.tensor_metadata_recv = None
+        self.grad_metadata_recv = None

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.
@@ -90,11 +90,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        # NOTE: disable metadata cache when batch size changes (not valid anymore)
        if self.batch_size != self.last_batch_size:
            self.enable_metadata_cache = False
-            self.send_metadata_forward = True
-            self.send_metadata_backward = True
-            self.metadata_recv_forward = None
-            self.metadata_recv_backward = None
+            self.send_tensor_metadata = True
+            self.send_grad_metadata = True
+            self.tensor_metadata_recv = None
+            self.grad_metadata_recv = None

        self.last_batch_size = self.batch_size

    def load_micro_batch(self) -> Any:
@@ -119,9 +119,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            Any: The input tensor or input tensor list.
        """
        if not self.stage_manager.is_first_stage():
-            input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-            if self.enable_metadata_cache and self.metadata_recv_forward is None:
-                self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+            input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)

            return input_tensor
@@ -136,13 +136,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            Any: The input gradient tensor or gradient tensor list.
        """
        if not self.stage_manager.is_last_stage():
-            output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-            if self.enable_metadata_cache and self.metadata_recv_backward is None:
-                self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+            output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+            if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)

            return output_tensor_grad

-    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
        """Sends the input tensor to the next stage in pipeline.
        For 1F1B.
@@ -151,10 +151,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
-            self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-            self.send_metadata_forward = not self.enable_metadata_cache
+            self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
+            self.send_tensor_metadata = not self.enable_metadata_cache

-    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.
        For 1F1B.
@@ -163,10 +163,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            prev_rank (int, optional): The rank of the recipient of the tensor
        """
        if not self.stage_manager.is_first_stage():
-            self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-            self.send_metadata_backward = not self.enable_metadata_cache
+            self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
+            self.send_grad_metadata = not self.enable_metadata_cache

-    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
+    def send_forward_recv_backward(
+        self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
+    ) -> Any:
        """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
        For 1F1B.
@@ -175,19 +177,24 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
+            if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
            output_tensor_grad = self.comm.send_forward_recv_backward(
-                output_object,
+                output_tensor,
                next_rank,
-                send_metadata=self.send_metadata_forward,
-                metadata_recv=self.metadata_recv_backward,
+                send_metadata=self.send_tensor_metadata,
+                metadata_recv=self.grad_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
            )
-            self.send_metadata_forward = not self.enable_metadata_cache
-            if self.enable_metadata_cache and self.metadata_recv_backward is None:
-                self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+            self.send_tensor_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)

            return output_tensor_grad

-    def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
+    def send_backward_recv_forward(
+        self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
+    ) -> Any:
        """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
        For 1F1B.
@@ -196,15 +203,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            prev_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_first_stage():
+            if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
            input_tensor = self.comm.send_backward_recv_forward(
-                output_object,
+                input_tensor_grad,
                prev_rank,
-                send_metadata=self.send_metadata_backward,
-                metadata_recv=self.metadata_recv_forward,
+                send_metadata=self.send_grad_metadata,
+                metadata_recv=self.tensor_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
            )
-            self.send_metadata_backward = not self.enable_metadata_cache
-            if self.enable_metadata_cache and self.metadata_recv_forward is None:
-                self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+            self.send_grad_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)

            return input_tensor
@@ -365,7 +375,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            last_iteration = i == (num_microbatches_remaining - 1)

            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
-            output_obj_grad = self.send_forward_recv_backward(output_obj)
+            output_obj_grad = self.send_forward_recv_backward(
+                output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
+            )

            # Add input_obj and output_obj to end of list.
            input_objs.append(input_obj)
            output_objs.append(output_obj)
@@ -379,7 +391,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            if last_iteration:
                self.send_backward(input_obj_grad)
            else:
-                input_obj = self.send_backward_recv_forward(input_obj_grad)
+                input_obj = self.send_backward_recv_forward(
+                    input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
+                )

        # Run cooldown backward passes.
        for i in range(num_warmup_microbatches):
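One detail shared by both schedule classes above, stated as a hedged sketch (the function and parameter names here are illustrative, not ColossalAI APIs): once the P2P metadata cache is warm, neither metadata sending nor metadata receiving is needed, the combined send/recv can stay a single atomic call, and the ordering hint is deliberately cleared, which is what the "must not fallback" lines express.

from typing import Optional


def resolve_send_prior_fallback(
    send_metadata: bool, metadata_recv_cached: bool, stage: int
) -> Optional[bool]:
    """Return the ordering hint for a combined send/recv, or None when no fallback may occur."""
    if not send_metadata and metadata_recv_cached:
        return None  # pair can be issued atomically; ordering is meaningless
    return stage % 2 == 0  # otherwise even/odd stages pick opposite orders


# e.g. resolve_send_prior_fallback(send_metadata=False, metadata_recv_cached=True, stage=3) is None,
# while resolve_send_prior_fallback(send_metadata=True, metadata_recv_cached=False, stage=3) is False.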


@@ -1,5 +1,3 @@
-import warnings
-
import pytest
import torch
import torch.distributed as dist
@@ -33,7 +31,7 @@ def check_p2p_communication():
        for obj in data:
            p2p.send_forward(obj)
        for i in range(len(data)):
-            recv_obj = p2p.send_forward_recv_backward(data[i])
+            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
            assert recv_obj == data[-(i + 1)]
    elif rank == 1:
        for obj in data:
@@ -48,7 +46,7 @@ def check_p2p_communication():
        for obj in data:
            p2p.send_backward(obj)
        for i in range(len(data)):
-            recv_obj = p2p.send_backward_recv_forward(data[i])
+            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
            assert recv_obj == data[-(i + 1)]
    elif rank == 0:
        for obj in data:
@@ -59,7 +57,6 @@ def check_p2p_communication():
            p2p.send_forward(data[-(i + 1)])
            assert recv_obj == data[i]

-    warnings.filterwarnings("error")
    tensor_metadata = TensorMetadata(
        key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
    )