mirror of https://github.com/hpcaitech/ColossalAI
[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)
* fix: add fallback order option and update 1f1b * fix: fix deadlock comm in interleaved pp * test: modify p2p testpull/4976/merge
parent
3c0d82b19b
commit
d799a3088f
|
@ -1,5 +1,4 @@
|
|||
import ctypes
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
@ -23,7 +22,6 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
|
|||
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
@ -984,13 +982,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
if os.getenv("NCCL_BUFFSIZE") is None:
|
||||
logger = get_dist_logger()
|
||||
logger.warning(
|
||||
"Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
|
||||
)
|
||||
os.environ["NCCL_BUFFSIZE"] = "134217728"
|
||||
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
|
|
|
@ -344,6 +344,7 @@ def _communicate(
|
|||
recv_group: Optional[ProcessGroup] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send and receive object from send_dst and recv_src respectively
|
||||
|
@ -368,8 +369,14 @@ def _communicate(
|
|||
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
|
||||
if send_prior_fallback:
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
else:
|
||||
recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return recv_data
|
||||
|
||||
# NOTE: only the following 5 cases are valid:
|
||||
# 1. send() [needs extra metadata] and no recv()
|
||||
|
@ -437,7 +444,7 @@ def _communicate(
|
|||
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
|
||||
|
||||
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
|
||||
"""send anything to dst rank
|
||||
|
||||
Args:
|
||||
|
@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
|
||||
|
||||
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
|
||||
"""recv anything from src
|
||||
|
||||
Args:
|
||||
|
@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona
|
|||
Returns:
|
||||
Any: Object received from src.
|
||||
"""
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
|
||||
|
||||
|
||||
def _p2p_comm(
|
||||
|
@ -557,7 +564,10 @@ class PipelineP2PCommunication:
|
|||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
input_tensor = _recv_object(
|
||||
prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
|
||||
prev_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
|
||||
metadata_recv=metadata_recv,
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
|
@ -575,7 +585,10 @@ class PipelineP2PCommunication:
|
|||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
output_tensor_grad = _recv_object(
|
||||
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
|
||||
next_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
|
||||
metadata_recv=metadata_recv,
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
|
@ -595,7 +608,7 @@ class PipelineP2PCommunication:
|
|||
cur_rank,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
send_metadata,
|
||||
send_metadata=send_metadata,
|
||||
)
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
|
@ -613,7 +626,7 @@ class PipelineP2PCommunication:
|
|||
cur_rank,
|
||||
prev_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
|
||||
send_metadata,
|
||||
send_metadata=send_metadata,
|
||||
)
|
||||
|
||||
def send_forward_recv_backward(
|
||||
|
@ -622,6 +635,7 @@ class PipelineP2PCommunication:
|
|||
next_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
||||
|
||||
|
@ -642,6 +656,7 @@ class PipelineP2PCommunication:
|
|||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
|
@ -650,6 +665,7 @@ class PipelineP2PCommunication:
|
|||
prev_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
||||
|
||||
|
@ -670,6 +686,7 @@ class PipelineP2PCommunication:
|
|||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
|
||||
def p2p_communicate(
|
||||
|
|
|
@ -41,10 +41,10 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_metadata_forward = True
|
||||
self.send_metadata_backward = True
|
||||
self.metadata_recv_forward = None
|
||||
self.metadata_recv_backward = None
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
@ -77,10 +77,10 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
# NOTE: disable metadata cache when batch size changes (not valid anymore)
|
||||
if self.batch_size != self.last_batch_size:
|
||||
self.enable_metadata_cache = False
|
||||
self.send_metadata_forward = True
|
||||
self.send_metadata_backward = True
|
||||
self.metadata_recv_forward = None
|
||||
self.metadata_recv_backward = None
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
|
@ -108,7 +108,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
int: The model chunk idx of the input microbatch_id
|
||||
"""
|
||||
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
assert microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not is_forward:
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
|
@ -127,9 +128,9 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
|
||||
if self.enable_metadata_cache and self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
@ -146,13 +147,13 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
|
||||
if self.enable_metadata_cache and self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -163,10 +164,10 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
|
||||
self.send_metadata_forward = not self.enable_metadata_cache
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -177,42 +178,96 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
|
||||
self.send_metadata_backward = not self.enable_metadata_cache
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
output_tensor: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_object,
|
||||
next_rank,
|
||||
send_metadata=self.send_metadata_forward,
|
||||
metadata_recv=self.metadata_recv_backward,
|
||||
)
|
||||
self.send_metadata_forward = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_last_stage()
|
||||
|
||||
return output_tensor_grad
|
||||
if send_data and recv_data:
|
||||
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad
|
||||
|
||||
# send only or recv only
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
return self.recv_backward(model_chunk_id_recv)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
input_tensor_grad: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
output_object,
|
||||
prev_rank,
|
||||
send_metadata=self.send_metadata_backward,
|
||||
metadata_recv=self.metadata_recv_forward,
|
||||
)
|
||||
self.send_metadata_backward = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_first_stage()
|
||||
|
||||
return input_tensor
|
||||
if send_data and recv_data:
|
||||
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
# send only or recv only
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
return self.recv_forward(model_chunk_id_recv)
|
||||
|
||||
def send_forward_recv_forward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
else:
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
|
@ -321,12 +376,23 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
|
||||
# Run warmup forward passes.
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
|
||||
for i in range(self.num_microbatch * self.num_model_chunks):
|
||||
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=model_chunk_id,
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
else:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
|
@ -364,54 +430,102 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
accum_loss = torch.scalar_tensor(0, device=get_current_device())
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatch):
|
||||
last_iteration = i == num_warmup_microbatch - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
|
||||
if last_iteration and num_microbatch_remaining == 0:
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
else:
|
||||
input_obj = self.send_forward_recv_forward(
|
||||
model_chunk_id_send=model_chunk_id,
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
|
||||
output_tensor=output_obj,
|
||||
send_prior=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
|
||||
if num_microbatch_remaining > 0:
|
||||
model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatch_remaining):
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
last_iteration = i == num_microbatch_remaining - 1
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
# Pop output_obj and output_obj from the start of the list for the backward pass.
|
||||
_input_obj = input_objs[model_chunk_id].pop(0)
|
||||
_output_obj = output_objs[model_chunk_id].pop(0)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
|
||||
# NOTE: perform 2x communication for forward and backward
|
||||
def send_forward_recv_backward():
|
||||
if last_iteration and num_microbatch == num_microbatch_remaining:
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
|
||||
self.send_forward(model_chunk_id, output_obj)
|
||||
else:
|
||||
output_obj_grad = self.send_forward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
output_tensor=output_obj,
|
||||
send_prior_fallback=self.stage_manager.stage % 2 == 0,
|
||||
)
|
||||
return output_obj_grad
|
||||
|
||||
def send_backward_recv_forward():
|
||||
if last_iteration:
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
else:
|
||||
input_obj = self.send_backward_recv_forward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
|
||||
input_tensor_grad=input_obj_grad,
|
||||
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
|
||||
)
|
||||
return input_obj
|
||||
|
||||
if self.stage_manager.stage % 2 == 0:
|
||||
output_obj_grad = send_forward_recv_backward()
|
||||
input_obj = send_backward_recv_forward()
|
||||
else:
|
||||
input_obj = send_backward_recv_forward()
|
||||
output_obj_grad = send_forward_recv_backward()
|
||||
|
||||
if num_microbatch_remaining == 0:
|
||||
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
|
||||
# Pop output_obj and output_obj from the start of the list for
|
||||
# the backward pass.
|
||||
input_obj = input_objs[model_chunk_id].pop(0)
|
||||
output_obj = output_objs[model_chunk_id].pop(0)
|
||||
|
||||
# backward
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
if not last_iteration:
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
|
||||
input_obj = self.recv_forward(model_chunk_id)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
for i in range(num_microbatch_remaining, num_microbatch):
|
||||
last_iteration = i == num_microbatch - 1
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
input_obj = input_objs[model_chunk_id].pop(0)
|
||||
output_obj = output_objs[model_chunk_id].pop(0)
|
||||
output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
_input_obj = input_objs[model_chunk_id].pop(0)
|
||||
_output_obj = output_objs[model_chunk_id].pop(0)
|
||||
# output_obj_grad = self.recv_backward(model_chunk_id)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
|
||||
if not last_iteration:
|
||||
output_obj_grad = self.send_backward_recv_backward(
|
||||
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
|
||||
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
|
||||
input_tensor_grad=input_obj_grad,
|
||||
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
|
||||
)
|
||||
else:
|
||||
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
|
||||
self.send_backward(model_chunk_id, input_obj_grad)
|
||||
|
||||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
|
|
|
@ -54,10 +54,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_metadata_forward = True
|
||||
self.send_metadata_backward = True
|
||||
self.metadata_recv_forward = None
|
||||
self.metadata_recv_backward = None
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
@ -90,11 +90,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
# NOTE: disable metadata cache when batch size changes (not valid anymore)
|
||||
if self.batch_size != self.last_batch_size:
|
||||
self.enable_metadata_cache = False
|
||||
self.send_metadata_forward = True
|
||||
self.send_metadata_backward = True
|
||||
self.metadata_recv_forward = None
|
||||
self.metadata_recv_backward = None
|
||||
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
def load_micro_batch(self) -> Any:
|
||||
|
@ -119,9 +119,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Any: The input tensor or input tensor list.
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
|
||||
if self.enable_metadata_cache and self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
@ -136,13 +136,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
Any: The input gradient tensor or gradient tensor list.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
|
||||
if self.enable_metadata_cache and self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
|
@ -151,10 +151,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
|
||||
self.send_metadata_forward = not self.enable_metadata_cache
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
|
@ -163,10 +163,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
|
||||
self.send_metadata_backward = not self.enable_metadata_cache
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
|
||||
def send_forward_recv_backward(
|
||||
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
|
||||
) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
|
@ -175,19 +177,24 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_object,
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_metadata_forward,
|
||||
metadata_recv=self.metadata_recv_backward,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_metadata_forward = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.metadata_recv_backward is None:
|
||||
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
|
||||
def send_backward_recv_forward(
|
||||
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
||||
|
@ -196,15 +203,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
output_object,
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_metadata_backward,
|
||||
metadata_recv=self.metadata_recv_forward,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_metadata_backward = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.metadata_recv_forward is None:
|
||||
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
@ -365,7 +375,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
|
||||
output_obj_grad = self.send_forward_recv_backward(output_obj)
|
||||
output_obj_grad = self.send_forward_recv_backward(
|
||||
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
|
||||
)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
@ -379,7 +391,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if last_iteration:
|
||||
self.send_backward(input_obj_grad)
|
||||
else:
|
||||
input_obj = self.send_backward_recv_forward(input_obj_grad)
|
||||
input_obj = self.send_backward_recv_forward(
|
||||
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
|
||||
)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -33,7 +31,7 @@ def check_p2p_communication():
|
|||
for obj in data:
|
||||
p2p.send_forward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_forward_recv_backward(data[i])
|
||||
recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 1:
|
||||
for obj in data:
|
||||
|
@ -48,7 +46,7 @@ def check_p2p_communication():
|
|||
for obj in data:
|
||||
p2p.send_backward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_backward_recv_forward(data[i])
|
||||
recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 0:
|
||||
for obj in data:
|
||||
|
@ -59,7 +57,6 @@ def check_p2p_communication():
|
|||
p2p.send_forward(data[-(i + 1)])
|
||||
assert recv_obj == data[i]
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
tensor_metadata = TensorMetadata(
|
||||
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue