diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index ea74f75f4..c52de0ba7 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,5 +1,4 @@
 import ctypes
-import os
 import random
 from contextlib import contextmanager
 from functools import partial
@@ -23,7 +22,6 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.logging import get_dist_logger
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -984,13 +982,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            if os.getenv("NCCL_BUFFSIZE") is None:
-                logger = get_dist_logger()
-                logger.warning(
-                    "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
-                )
-                os.environ["NCCL_BUFFSIZE"] = "134217728"
-
             assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
             assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index cdb7a6a1e..d32ff501f 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -344,6 +344,7 @@ def _communicate(
     recv_group: Optional[ProcessGroup] = None,
     send_metadata: bool = True,
     metadata_recv: Optional[P2PMetadata] = None,
+    send_prior_fallback: Optional[bool] = None,
 ) -> Any:
     """
     Send and receive object from send_dst and recv_src respectively
@@ -368,8 +369,14 @@ def _communicate(
     # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
     # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
     if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
-        _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
-        return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+        assert send_prior_fallback is not None, "Priority must be set if fallback happens"
+        if send_prior_fallback:
+            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+            return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+        else:
+            recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+            return recv_data

     # NOTE: only the following 5 cases are valid:
     #       1. send() [needs extra metadata] and no recv()
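Why the fallback must be ordered: when a paired send and recv also have to carry metadata, the single batched P2P op is unavailable and the pair degrades into two blocking calls. If both peers sent first, each would block waiting for a matching recv. A minimal standalone sketch of the ordering contract (illustrative only; send_fn/recv_fn are hypothetical stand-ins for the two blocking halves, not ColossalAI APIs):

    def paired_send_recv(send_fn, recv_fn, send_prior_fallback: bool):
        # One peer must pass send_prior_fallback=True and the other False,
        # so every blocking send meets an already-posted recv on the far side.
        if send_prior_fallback:
            send_fn()          # send first ...
            return recv_fn()   # ... then receive
        recv_data = recv_fn()  # receive first ...
        send_fn()              # ... then send
        return recv_data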
@@ -437,7 +444,7 @@ def _communicate(
         raise ValueError("Unknown data type {}".format(metadata_recv.data_type))


-def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
+def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
     """send anything to dst rank

     Args:
@@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta
     Returns:
         None
     """
-    _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)


-def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
+def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
     """recv anything from src

     Args:
@@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona
     Returns:
         Any: Object received from src.
     """
-    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)


 def _p2p_comm(
@@ -557,7 +564,10 @@ class PipelineP2PCommunication:
         prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
         input_tensor = _recv_object(
-            prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
+            prev_rank,
+            cur_rank,
+            self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
+            metadata_recv=metadata_recv,
         )

         return input_tensor
@@ -575,7 +585,10 @@ class PipelineP2PCommunication:
         next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         output_tensor_grad = _recv_object(
-            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
+            next_rank,
+            cur_rank,
+            self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
+            metadata_recv=metadata_recv,
         )

         return output_tensor_grad
@@ -595,7 +608,7 @@ class PipelineP2PCommunication:
             cur_rank,
             next_rank,
             self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
-            send_metadata,
+            send_metadata=send_metadata,
         )

     def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
@@ -613,7 +626,7 @@ class PipelineP2PCommunication:
             cur_rank,
             prev_rank,
             self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
-            send_metadata,
+            send_metadata=send_metadata,
         )

     def send_forward_recv_backward(
@@ -622,6 +635,7 @@ class PipelineP2PCommunication:
         next_rank: Optional[int] = None,
         send_metadata: bool = True,
         metadata_recv: Optional[P2PMetadata] = None,
+        send_prior_fallback: Optional[bool] = None,
     ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
@@ -642,6 +656,7 @@ class PipelineP2PCommunication:
             recv_group=group,
             send_metadata=send_metadata,
             metadata_recv=metadata_recv,
+            send_prior_fallback=send_prior_fallback,
         )

     def send_backward_recv_forward(
@@ -650,6 +665,7 @@ class PipelineP2PCommunication:
         prev_rank: Optional[int] = None,
         send_metadata: bool = True,
         metadata_recv: Optional[P2PMetadata] = None,
+        send_prior_fallback: Optional[bool] = None,
     ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
@@ -670,6 +686,7 @@ class PipelineP2PCommunication:
             recv_group=group,
             send_metadata=send_metadata,
             metadata_recv=metadata_recv,
+            send_prior_fallback=send_prior_fallback,
         )

     def p2p_communicate(
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 23b3f4e6c..aa18a8520 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -41,10 +41,10 @@ class InterleavedSchedule(PipelineSchedule):

         # P2PMeta cache
         self.enable_metadata_cache = enable_metadata_cache
-        self.send_metadata_forward = True
-        self.send_metadata_backward = True
-        self.metadata_recv_forward = None
-        self.metadata_recv_backward = None
+        self.send_tensor_metadata = True
+        self.send_grad_metadata = True
+        self.tensor_metadata_recv = None
+        self.grad_metadata_recv = None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -77,10 +77,10 @@ class InterleavedSchedule(PipelineSchedule):
         # NOTE: disable metadata cache when batch size changes (not valid anymore)
         if self.batch_size != self.last_batch_size:
             self.enable_metadata_cache = False
-            self.send_metadata_forward = True
-            self.send_metadata_backward = True
-            self.metadata_recv_forward = None
-            self.metadata_recv_backward = None
+            self.send_tensor_metadata = True
+            self.send_grad_metadata = True
+            self.tensor_metadata_recv = None
+            self.grad_metadata_recv = None

         self.last_batch_size = self.batch_size

@@ -108,7 +108,8 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             int: The model chunk idx of the input microbatch_id
         """
-        microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
+        assert microbatch_id < self.num_microbatch * self.num_model_chunks
+        microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
         model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
         if not is_forward:
             model_chunk_id = self.num_model_chunks - model_chunk_id - 1
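The new assertion guards the round-robin mapping below it. A standalone rendering of the same arithmetic (pure Python, names invented for the sketch) shows how microbatches cycle through model chunks in groups of num_stages * num_model_chunks, with the order reversed for backward:

    def chunk_id(microbatch_id, num_stages, num_model_chunks, is_forward):
        # Same formula as get_model_chunk_id above, without the self.* lookups.
        in_group = microbatch_id % (num_stages * num_model_chunks)
        chunk = in_group // num_stages
        return chunk if is_forward else num_model_chunks - chunk - 1

    # 4 stages, 2 chunks: forward microbatches 0-7 map to chunks 0,0,0,0,1,1,1,1.
    assert [chunk_id(i, 4, 2, True) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
    assert [chunk_id(i, 4, 2, False) for i in range(8)] == [1, 1, 1, 1, 0, 0, 0, 0]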
@@ -127,9 +128,9 @@ class InterleavedSchedule(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
-                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-                if self.enable_metadata_cache and self.metadata_recv_forward is None:
-                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                    self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)

                 return input_tensor

@@ -146,13 +147,13 @@ class InterleavedSchedule(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
-                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-                if self.enable_metadata_cache and self.metadata_recv_backward is None:
-                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+                if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                    self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)

                 return output_tensor_grad

-    def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
         """Sends the input tensor to the next stage in pipeline.
         For interleaved 1F1B.
@@ -163,10 +164,10 @@ class InterleavedSchedule(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
-                self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-                self.send_metadata_forward = not self.enable_metadata_cache
+                self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
+                self.send_tensor_metadata = not self.enable_metadata_cache

-    def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
         For interleaved 1F1B.
@@ -177,42 +178,96 @@ class InterleavedSchedule(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
-                self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-                self.send_metadata_backward = not self.enable_metadata_cache
+                self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
+                self.send_grad_metadata = not self.enable_metadata_cache

     def send_forward_recv_backward(
-        self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
+        self,
+        model_chunk_id_send: int,
+        model_chunk_id_recv: int,
+        output_tensor: Any,
+        next_rank: Optional[int] = None,
+        send_prior_fallback: Optional[bool] = None,
     ) -> Any:
-        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            if not self.stage_manager.is_last_stage():
-                output_tensor_grad = self.comm.send_forward_recv_backward(
-                    output_object,
-                    next_rank,
-                    send_metadata=self.send_metadata_forward,
-                    metadata_recv=self.metadata_recv_backward,
-                )
-                self.send_metadata_forward = not self.enable_metadata_cache
-                if self.enable_metadata_cache and self.metadata_recv_backward is None:
-                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+            send_data = not self.stage_manager.is_last_stage()
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+            recv_data = not self.stage_manager.is_last_stage()

-                return output_tensor_grad
+        if send_data and recv_data:
+            if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
+            output_tensor_grad = self.comm.send_forward_recv_backward(
+                output_tensor,
+                next_rank,
+                send_metadata=self.send_tensor_metadata,
+                metadata_recv=self.grad_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
+            )
+            self.send_tensor_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+            return output_tensor_grad
+
+        # send only or recv only
+        self.send_forward(model_chunk_id_send, output_tensor)
+        return self.recv_backward(model_chunk_id_recv)

     def send_backward_recv_forward(
-        self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
+        self,
+        model_chunk_id_send: int,
+        model_chunk_id_recv: int,
+        input_tensor_grad: Any,
+        prev_rank: Optional[int] = None,
+        send_prior_fallback: Optional[bool] = None,
     ) -> Any:
-        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            if not self.stage_manager.is_first_stage():
-                input_tensor = self.comm.send_backward_recv_forward(
-                    output_object,
-                    prev_rank,
-                    send_metadata=self.send_metadata_backward,
-                    metadata_recv=self.metadata_recv_forward,
-                )
-                self.send_metadata_backward = not self.enable_metadata_cache
-                if self.enable_metadata_cache and self.metadata_recv_forward is None:
-                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+            send_data = not self.stage_manager.is_first_stage()
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+            recv_data = not self.stage_manager.is_first_stage()

-                return input_tensor
+        if send_data and recv_data:
+            if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
+            input_tensor = self.comm.send_backward_recv_forward(
+                input_tensor_grad,
+                prev_rank,
+                send_metadata=self.send_grad_metadata,
+                metadata_recv=self.tensor_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
+            )
+            self.send_grad_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+            return input_tensor
+
+        # send only or recv only
+        self.send_backward(model_chunk_id_send, input_tensor_grad)
+        return self.recv_forward(model_chunk_id_recv)
+
+    def send_forward_recv_forward(
+        self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
+    ):
+        if send_prior:
+            self.send_forward(model_chunk_id_send, output_tensor)
+            input_tensor = self.recv_forward(model_chunk_id_recv)
+        else:
+            input_tensor = self.recv_forward(model_chunk_id_recv)
+            self.send_forward(model_chunk_id_send, output_tensor)
+
+        return input_tensor
+
+    def send_backward_recv_backward(
+        self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
+    ):
+        if send_prior:
+            self.send_backward(model_chunk_id_send, input_tensor_grad)
+            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
+        else:
+            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
+            self.send_backward(model_chunk_id_send, input_tensor_grad)
+
+        return output_tensor_grad

     def forward_step(
         self,
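The two new fused helpers take a send_prior flag rather than a fallback hint because a forward send and a forward recv can never be batched into one op (they target different peers), so the order must always be chosen explicitly. Callers derive it from stage parity; a toy illustration (assumed 4-stage chain) of why alternating parity is deadlock-free:

    for stage in range(4):
        send_prior = stage % 2 == 0
        order = "send -> recv" if send_prior else "recv -> send"
        print(f"stage {stage}: {order}")
    # stage 0: send -> recv    (its neighbor, stage 1, receives first)
    # stage 1: recv -> send    (its neighbor, stage 2, sends first)
    # stage 2: send -> recv
    # stage 3: recv -> send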
@@ -321,12 +376,23 @@ class InterleavedSchedule(PipelineSchedule):
         if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
             accum_loss = torch.scalar_tensor(0, device=get_current_device())

-        # Run warmup forward passes.
+        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
+        input_obj = self.recv_forward(model_chunk_id)
+
         for i in range(self.num_microbatch * self.num_model_chunks):
+            last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            self.send_forward(model_chunk_id, output_obj)
+
+            if not last_iteration:
+                input_obj = self.send_forward_recv_forward(
+                    model_chunk_id_send=model_chunk_id,
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
+                    output_tensor=output_obj,
+                    send_prior=self.stage_manager.stage % 2 == 0,
+                )
+            else:
+                self.send_forward(model_chunk_id, output_obj)

         if outputs is not None:
             outputs = merge_batch(outputs)
@@ -364,54 +430,102 @@ class InterleavedSchedule(PipelineSchedule):
         if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
             accum_loss = torch.scalar_tensor(0, device=get_current_device())

+        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
+        input_obj = self.recv_forward(model_chunk_id)
+
         # Run warmup forward passes.
         for i in range(num_warmup_microbatch):
+            last_iteration = i == num_warmup_microbatch - 1
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             input_objs[model_chunk_id].append(input_obj)
             output_objs[model_chunk_id].append(output_obj)
-            self.send_forward(model_chunk_id, output_obj)
+
+            if last_iteration and num_microbatch_remaining == 0:
+                self.send_forward(model_chunk_id, output_obj)
+            else:
+                input_obj = self.send_forward_recv_forward(
+                    model_chunk_id_send=model_chunk_id,
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
+                    output_tensor=output_obj,
+                    send_prior=self.stage_manager.stage % 2 == 0,
+                )

         if num_microbatch_remaining > 0:
-            model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
+            model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
+            output_obj_grad = self.recv_backward(model_chunk_id)

         # Run 1F1B in steady state.
         for i in range(num_microbatch_remaining):
-            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
             last_iteration = i == num_microbatch_remaining - 1
+            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)

             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            self.send_forward(model_chunk_id, output_obj)
             # Add input_obj and output_obj to end of list.
             input_objs[model_chunk_id].append(input_obj)
             output_objs[model_chunk_id].append(output_obj)

             model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+            # Pop input_obj and output_obj from the start of the list for the backward pass.
+            _input_obj = input_objs[model_chunk_id].pop(0)
+            _output_obj = output_objs[model_chunk_id].pop(0)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+
+            # NOTE: perform 2x communication for forward and backward
+            def send_forward_recv_backward():
+                if last_iteration and num_microbatch == num_microbatch_remaining:
+                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+                    self.send_forward(model_chunk_id, output_obj)
+                else:
+                    output_obj_grad = self.send_forward_recv_backward(
+                        model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                        output_tensor=output_obj,
+                        send_prior_fallback=self.stage_manager.stage % 2 == 0,
+                    )
+                    return output_obj_grad
+
+            def send_backward_recv_forward():
+                if last_iteration:
+                    model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                    self.send_backward(model_chunk_id, input_obj_grad)
+                else:
+                    input_obj = self.send_backward_recv_forward(
+                        model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
+                        input_tensor_grad=input_obj_grad,
+                        send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
+                    )
+                    return input_obj
+
+            if self.stage_manager.stage % 2 == 0:
+                output_obj_grad = send_forward_recv_backward()
+                input_obj = send_backward_recv_forward()
+            else:
+                input_obj = send_backward_recv_forward()
+                output_obj_grad = send_forward_recv_backward()
+
+        if num_microbatch_remaining == 0:
+            model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
             output_obj_grad = self.recv_backward(model_chunk_id)
-
-            # Pop output_obj and output_obj from the start of the list for
-            # the backward pass.
-            input_obj = input_objs[model_chunk_id].pop(0)
-            output_obj = output_objs[model_chunk_id].pop(0)
-
-            # backward
-            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-            self.send_backward(model_chunk_id, input_obj_grad)
-
-            if not last_iteration:
-                model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
-                input_obj = self.recv_forward(model_chunk_id)
-
         # Run cooldown backward passes.
         for i in range(num_microbatch_remaining, num_microbatch):
+            last_iteration = i == num_microbatch - 1
             model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-            input_obj = input_objs[model_chunk_id].pop(0)
-            output_obj = output_objs[model_chunk_id].pop(0)
-            output_obj_grad = self.recv_backward(model_chunk_id)
-            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-            self.send_backward(model_chunk_id, input_obj_grad)
+            _input_obj = input_objs[model_chunk_id].pop(0)
+            _output_obj = output_objs[model_chunk_id].pop(0)
+            # output_obj_grad = self.recv_backward(model_chunk_id)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+
+            if not last_iteration:
+                output_obj_grad = self.send_backward_recv_backward(
+                    model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                    input_tensor_grad=input_obj_grad,
+                    send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
+                )
+            else:
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                self.send_backward(model_chunk_id, input_obj_grad)

         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
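The rewritten schedule now threads input_obj and output_obj_grad through the loops instead of re-receiving them each iteration, so the final assert is what verifies that the per-chunk FIFOs drained. A toy model of that bookkeeping (pure Python; the stage and microbatch counts are made up for illustration):

    num_stages, num_model_chunks, num_microbatch = 4, 2, 4

    def chunk_of(i, is_forward):
        chunk = (i % (num_stages * num_model_chunks)) // num_stages
        return chunk if is_forward else num_model_chunks - chunk - 1

    input_objs = [[] for _ in range(num_model_chunks)]
    for i in range(num_microbatch * num_model_chunks):  # forward: push activations
        input_objs[chunk_of(i, True)].append(f"act_{i}")
    for i in range(num_microbatch * num_model_chunks):  # backward: pop in FIFO order
        input_objs[chunk_of(i, False)].pop(0)
    assert all(len(v) == 0 for v in input_objs)  # mirrors the schedule's final assert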
""" if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + if self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad) return output_tensor_grad - def send_forward(self, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, output_tensor: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. For 1F1B. @@ -151,10 +151,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): next_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) - self.send_metadata_forward = not self.enable_metadata_cache + self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) + self.send_tensor_metadata = not self.enable_metadata_cache - def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. @@ -163,10 +163,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) - self.send_metadata_backward = not self.enable_metadata_cache + self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) + self.send_grad_metadata = not self.enable_metadata_cache - def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: + def send_forward_recv_backward( + self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None + ) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. For 1F1B. @@ -175,19 +177,24 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): next_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_last_stage(): + if not self.send_tensor_metadata and self.grad_metadata_recv is not None: + send_prior_fallback = None # must not fallback output_tensor_grad = self.comm.send_forward_recv_backward( - output_object, + output_tensor, next_rank, - send_metadata=self.send_metadata_forward, - metadata_recv=self.metadata_recv_backward, + send_metadata=self.send_tensor_metadata, + metadata_recv=self.grad_metadata_recv, + send_prior_fallback=send_prior_fallback, ) - self.send_metadata_forward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + self.send_tensor_metadata = not self.enable_metadata_cache + if self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad) return output_tensor_grad - def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: + def send_backward_recv_forward( + self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None + ) -> Any: """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. For 1F1B. @@ -196,15 +203,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): prev_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_first_stage(): + if not self.send_grad_metadata and self.tensor_metadata_recv is not None: + send_prior_fallback = None # must not fallback input_tensor = self.comm.send_backward_recv_forward( - output_object, + input_tensor_grad, prev_rank, - send_metadata=self.send_metadata_backward, - metadata_recv=self.metadata_recv_forward, + send_metadata=self.send_grad_metadata, + metadata_recv=self.tensor_metadata_recv, + send_prior_fallback=send_prior_fallback, ) - self.send_metadata_backward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + self.send_grad_metadata = not self.enable_metadata_cache + if self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_fast_send_metadata(input_tensor) return input_tensor @@ -365,7 +375,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - output_obj_grad = self.send_forward_recv_backward(output_obj) + output_obj_grad = self.send_forward_recv_backward( + output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0 + ) # Add input_obj and output_obj to end of list. input_objs.append(input_obj) output_objs.append(output_obj) @@ -379,7 +391,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if last_iteration: self.send_backward(input_obj_grad) else: - input_obj = self.send_backward_recv_forward(input_obj_grad) + input_obj = self.send_backward_recv_forward( + input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0 + ) # Run cooldown backward passes. 
diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py
index 40b6ac8eb..1c859fd93 100644
--- a/tests/test_pipeline/test_p2p_communication.py
+++ b/tests/test_pipeline/test_p2p_communication.py
@@ -1,5 +1,3 @@
-import warnings
-
 import pytest
 import torch
 import torch.distributed as dist
@@ -33,7 +31,7 @@ def check_p2p_communication():
         for obj in data:
             p2p.send_forward(obj)
         for i in range(len(data)):
-            recv_obj = p2p.send_forward_recv_backward(data[i])
+            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
             assert recv_obj == data[-(i + 1)]
     elif rank == 1:
         for obj in data:
@@ -48,7 +46,7 @@ def check_p2p_communication():
         for obj in data:
             p2p.send_backward(obj)
         for i in range(len(data)):
-            recv_obj = p2p.send_backward_recv_forward(data[i])
+            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
             assert recv_obj == data[-(i + 1)]
     elif rank == 0:
         for obj in data:
@@ -59,7 +57,6 @@ def check_p2p_communication():
             p2p.send_forward(data[-(i + 1)])
             assert recv_obj == data[i]

-    warnings.filterwarnings("error")
     tensor_metadata = TensorMetadata(
         key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
     )