diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 3c8b00977..23b3f4e6c 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import torch
 import torch.cuda
@@ -22,6 +22,7 @@ class InterleavedSchedule(PipelineSchedule):
         num_model_chunks: int,
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
+        enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__(stage_manager)
         assert (
@@ -39,6 +40,7 @@ class InterleavedSchedule(PipelineSchedule):
         self.microbatch_offset: List[int]
 
         # P2PMeta cache
+        self.enable_metadata_cache = enable_metadata_cache
         self.send_metadata_forward = True
         self.send_metadata_backward = True
         self.metadata_recv_forward = None
@@ -54,30 +56,33 @@ class InterleavedSchedule(PipelineSchedule):
         batch = next(data_iter)
         if device is not None:
             batch = tree_map(partial(to_device, device=device), batch)
+
+        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
         self.batch = batch
         self.batch_size = get_batch_size(batch)
-        if self.last_batch_size is None:
-            self.last_batch_size = self.batch_size
-        else:
-            assert self.forward_only or self.last_batch_size == self.batch_size
-            # TODO: support arbitrary batch size when forward_only=True
-        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        if self.num_microbatch is not None:
+
+        if self.microbatch_size is None:
             assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
             self.microbatch_size = self.batch_size // self.num_microbatch
-        elif self.microbatch_size is not None:
+        if self.num_microbatch is None:
             assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
             self.num_microbatch = self.batch_size // self.microbatch_size
-        else:
-            raise ValueError("Either num_microbatch or microbatch_size should be provided")
 
-        assert (
-            self.num_microbatch % self.num_model_chunks == 0
-        ), "Number of microbatch should be an integer multiple of number of model chunks"
+        if not self.forward_only:
+            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
+            assert self.batch_size == self.microbatch_size * self.num_microbatch
 
-        assert (
-            self.num_microbatch % self.stage_manager.num_stages == 0
-        ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
+        if self.forward_only:
+            self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
+            # NOTE: disable metadata cache when batch size changes (not valid anymore)
+            if self.batch_size != self.last_batch_size:
+                self.enable_metadata_cache = False
+                self.send_metadata_forward = True
+                self.send_metadata_backward = True
+                self.metadata_recv_forward = None
+                self.metadata_recv_backward = None
+
+        self.last_batch_size = self.batch_size
 
     def load_micro_batch(self, model_chunk_id: int) -> Any:
         """Load a micro batch from the current batch.
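
The load_batch hunk above changes the sizing contract: training still requires the batch to split evenly into microbatches, while forward-only execution now ceil-divides so the final microbatch may be partial. A minimal sketch of that rule, with illustrative names (not the library API):

    def num_microbatches(batch_size: int, microbatch_size: int, forward_only: bool) -> int:
        if forward_only:
            # Ceil division: tolerate a ragged tail at inference time.
            return (batch_size - 1) // microbatch_size + 1
        # Training keeps the strict contract: an exact split is required.
        assert batch_size % microbatch_size == 0
        return batch_size // microbatch_size

    assert num_microbatches(32, 8, forward_only=False) == 4
    assert num_microbatches(10, 4, forward_only=True) == 3  # last microbatch holds 2 samples

This is also why the hunk resets the P2P metadata state whenever the batch size changes: tensor shapes cached from the previous batch are no longer valid.
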
@@ -88,6 +93,7 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: Micro batch.
         """
+        assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
@@ -122,7 +128,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
                 input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-                if self.metadata_recv_forward is None:
+                if self.enable_metadata_cache and self.metadata_recv_forward is None:
                     self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
             return input_tensor
@@ -141,7 +147,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
                 output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-                if self.metadata_recv_backward is None:
+                if self.enable_metadata_cache and self.metadata_recv_backward is None:
                     self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -158,7 +164,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
                 self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-                self.send_metadata_forward = False
+                self.send_metadata_forward = not self.enable_metadata_cache
 
     def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
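
The "Microbatches exhausted" guard added in the load_micro_batch hunk above works together with Python slice semantics: a slice that runs past the end of a tensor is truncated rather than raising, so the final microbatch of a forward-only batch can simply be smaller. An illustrative sketch on a plain dict batch (the library's get_micro_batch applies the same idea across a possibly nested batch):

    import torch

    batch = {"input_ids": torch.arange(10)}
    microbatch_size, offset = 4, 8
    micro = {k: v[offset : offset + microbatch_size] for k, v in batch.items()}
    assert micro["input_ids"].tolist() == [8, 9]  # partial last microbatch
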
@@ -172,7 +178,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
                 self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-                self.send_metadata_backward = False
+                self.send_metadata_backward = not self.enable_metadata_cache
 
     def send_forward_recv_backward(
         self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
@@ -185,8 +191,8 @@ class InterleavedSchedule(PipelineSchedule):
                 send_metadata=self.send_metadata_forward,
                 metadata_recv=self.metadata_recv_backward,
             )
-            self.send_metadata_forward = False
-            if self.metadata_recv_backward is None:
+            self.send_metadata_forward = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.metadata_recv_backward is None:
                 self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -202,8 +208,8 @@ class InterleavedSchedule(PipelineSchedule):
                 send_metadata=self.send_metadata_backward,
                 metadata_recv=self.metadata_recv_forward,
             )
-            self.send_metadata_backward = False
-            if self.metadata_recv_forward is None:
+            self.send_metadata_backward = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.metadata_recv_forward is None:
                 self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
             return input_tensor
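
Taken together, the send/recv hunks above implement a simple protocol: with the cache enabled, tensor metadata crosses the wire only on the first exchange; with it disabled (for example, after a batch-size change), every P2P call re-sends metadata. A hedged, self-contained sketch of that state machine (illustrative class, not the library API):

    class MetadataCache:
        def __init__(self, enable_metadata_cache: bool = True):
            self.enable_metadata_cache = enable_metadata_cache
            self.send_metadata = True  # the first send must carry metadata
            self.metadata_recv = None  # filled in after the first receive

        def after_send(self) -> None:
            # Once metadata has gone out, skip it on later sends if caching is on.
            self.send_metadata = not self.enable_metadata_cache

        def after_recv(self, metadata) -> None:
            if self.enable_metadata_cache and self.metadata_recv is None:
                self.metadata_recv = metadata  # reused by all later receives

    cache = MetadataCache()
    assert cache.send_metadata is True
    cache.after_send()
    assert cache.send_metadata is False  # subsequent sends skip the handshake
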
@@ -297,66 +303,74 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj_grad[k] = v.grad
         return input_obj_grad
 
-    def forward_backward_step(
+    def run_forward_only(
         self,
         model_chunk: Union[ModuleList, Module],
         data_iter: Iterable,
         criterion: Callable[..., Any],
-        optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ) -> dict:
-        """Runs interleaved schedule, with communication between pipeline stages.
+    ) -> Dict:
+        assert self.forward_only
 
-        Args:
-            model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
-            data_iter (Iterable): Data iterator.
-            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
-            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
-            return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
-            return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
+        self.load_batch(data_iter)
 
-        Returns:
-            dict: A dict with keys: 'loss' and 'outputs'.
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
+
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+
+        # Run warmup forward passes.
+        for i in range(self.num_microbatch * self.num_model_chunks):
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+            input_obj = self.recv_forward(model_chunk_id)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+            self.send_forward(model_chunk_id, output_obj)
+
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
+    def run_forward_backward(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
         """
-        self.forward_only = not torch.is_grad_enabled()
-        if optimizer is None:
-            assert self.forward_only, "Optimizer should be passed when doing backward."
+        Runs interleaved schedule, with communication between pipeline stages.
+        """
+        assert not self.forward_only
 
         self.load_batch(data_iter)
 
         num_microbatch = self.num_microbatch * self.num_model_chunks
-        if self.forward_only:
-            num_warmup_microbatch = num_microbatch
-        else:
-            num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
-            num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
-            num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
-
+        num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+        num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
+        num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
         num_microbatch_remaining = num_microbatch - num_warmup_microbatch
 
         # Input, output tensors only need to be saved when doing backward passes
-        input_objs = None
-        output_objs = None
-
-        if not self.forward_only:
-            input_objs = [[] for _ in range(self.num_model_chunks)]
-            output_objs = [[] for _ in range(self.num_model_chunks)]
+        input_objs = [[] for _ in range(self.num_model_chunks)]
+        output_objs = [[] for _ in range(self.num_model_chunks)]
 
         outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
 
         accum_loss = None
         if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
 
         # Run warmup forward passes.
         for i in range(num_warmup_microbatch):
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
             input_obj = self.recv_forward(model_chunk_id)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            if not self.forward_only:
-                input_objs[model_chunk_id].append(input_obj)
-                output_objs[model_chunk_id].append(output_obj)
+            input_objs[model_chunk_id].append(input_obj)
+            output_objs[model_chunk_id].append(output_obj)
             self.send_forward(model_chunk_id, output_obj)
 
         if num_microbatch_remaining > 0:
@@ -369,47 +383,72 @@ class InterleavedSchedule(PipelineSchedule):
             last_iteration = i == num_microbatch_remaining - 1
 
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            if self.forward_only:
-                if not last_iteration:
-                    input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
-                else:
-                    self.send_forward(model_chunk_id, output_obj)
-
-            else:
-                self.send_forward(model_chunk_id, output_obj)
-                # Add input_obj and output_obj to end of list.
-                input_objs[model_chunk_id].append(input_obj)
-                output_objs[model_chunk_id].append(output_obj)
+            self.send_forward(model_chunk_id, output_obj)
+            # Add input_obj and output_obj to end of list.
+            input_objs[model_chunk_id].append(input_obj)
+            output_objs[model_chunk_id].append(output_obj)
 
-                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-                output_obj_grad = self.recv_backward(model_chunk_id)
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+            output_obj_grad = self.recv_backward(model_chunk_id)
 
-                # Pop output_obj and output_obj from the start of the list for
-                # the backward pass.
-                input_obj = input_objs[model_chunk_id].pop(0)
-                output_obj = output_objs[model_chunk_id].pop(0)
+            # Pop input_obj and output_obj from the start of the list for
+            # the backward pass.
+            input_obj = input_objs[model_chunk_id].pop(0)
+            output_obj = output_objs[model_chunk_id].pop(0)
 
-                # backward
-                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-                self.send_backward(model_chunk_id, input_obj_grad)
+            # backward
+            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+            self.send_backward(model_chunk_id, input_obj_grad)
 
-                if not last_iteration:
-                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
-                    input_obj = self.recv_forward(model_chunk_id)
+            if not last_iteration:
+                model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
+                input_obj = self.recv_forward(model_chunk_id)
 
         # Run cooldown backward passes.
-        if not self.forward_only:
-            for i in range(num_microbatch_remaining, num_microbatch):
-                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-                input_obj = input_objs[model_chunk_id].pop(0)
-                output_obj = output_objs[model_chunk_id].pop(0)
-                output_obj_grad = self.recv_backward(model_chunk_id)
-                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-                self.send_backward(model_chunk_id, input_obj_grad)
+        for i in range(num_microbatch_remaining, num_microbatch):
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+            input_obj = input_objs[model_chunk_id].pop(0)
+            output_obj = output_objs[model_chunk_id].pop(0)
+            output_obj_grad = self.recv_backward(model_chunk_id)
+            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+            self.send_backward(model_chunk_id, input_obj_grad)
 
-        if not self.forward_only:
-            assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys: 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(
+                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
+            )
+
+        return result
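
run_forward_backward above keeps the interleaved warmup formula: earlier stages run more forward passes before the steady 1F1B phase begins, plus an extra num_stages passes per additional model chunk. A quick sanity check of the arithmetic (illustrative helper, same formula as the hunk):

    def warmup_microbatches(stage: int, num_stages: int, num_model_chunks: int, num_microbatch: int) -> int:
        warmup = (num_stages - stage - 1) * 2
        warmup += (num_model_chunks - 1) * num_stages
        return min(warmup, num_microbatch)

    # 4 stages, 2 model chunks, 8 microbatches per chunk -> 16 in total:
    assert warmup_microbatches(0, 4, 2, 16) == 10  # first stage warms up longest
    assert warmup_microbatches(3, 4, 2, 16) == 4   # last stage enters 1F1B soonest
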
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 8c161efec..6b2436d54 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import torch
 import torch.cuda
@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         stage_manager: PipelineStageManager,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
+        enable_metadata_cache: bool = True,
     ) -> None:
         """1F1B pipeline schedule.
 
@@ -50,9 +51,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch_size: Optional[int] = None
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self._use_microbatch_size = num_microbatches is None
 
         # P2PMeta cache
+        self.enable_metadata_cache = enable_metadata_cache
         self.send_metadata_forward = True
         self.send_metadata_backward = True
         self.metadata_recv_forward = None
@@ -69,29 +70,40 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if device is not None:
             batch = tree_map(partial(to_device, device=device), batch)
 
+        self.microbatch_offset = 0
         self.batch = batch
         self.batch_size = get_batch_size(batch)
-        if self.last_batch_size is None:
-            self.last_batch_size = self.batch_size
-        else:
-            assert self.forward_only or self.last_batch_size == self.batch_size
-            # TODO: support arbitrary batch size when forward_only=True
-        self.microbatch_offset = 0
-        if not self._use_microbatch_size:
-            assert (
-                self.batch_size % self.num_microbatches == 0
-            ), "Batch size should divided by the number of microbatches"
+
+        if self.microbatch_size is None:
+            assert self.batch_size % self.num_microbatches == 0, "Batch size should be divisible by # microbatches"
             self.microbatch_size = self.batch_size // self.num_microbatches
-        else:
+        if self.num_microbatches is None:
             assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
             self.num_microbatches = self.batch_size // self.microbatch_size
+
+        if not self.forward_only:
+            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
+            assert self.batch_size == self.microbatch_size * self.num_microbatches
+
+        if self.forward_only:
+            self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
+            # NOTE: disable metadata cache when batch size changes (not valid anymore)
+            if self.batch_size != self.last_batch_size:
+                self.enable_metadata_cache = False
+                self.send_metadata_forward = True
+                self.send_metadata_backward = True
+                self.metadata_recv_forward = None
+                self.metadata_recv_backward = None
+
+        self.last_batch_size = self.batch_size
+
     def load_micro_batch(self) -> Any:
         """Load a micro batch from the current batch.
 
         Returns:
             Any: Micro batch.
         """
+        assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
@@ -108,7 +120,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-            if self.metadata_recv_forward is None:
+            if self.enable_metadata_cache and self.metadata_recv_forward is None:
                 self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
         return input_tensor
@@ -125,7 +137,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-            if self.metadata_recv_backward is None:
+            if self.enable_metadata_cache and self.metadata_recv_backward is None:
                 self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
         return output_tensor_grad
@@ -140,7 +152,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_last_stage():
             self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-            self.send_metadata_forward = False
+            self.send_metadata_forward = not self.enable_metadata_cache
 
     def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
@@ -152,7 +164,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_first_stage():
             self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-            self.send_metadata_backward = False
+            self.send_metadata_backward = not self.enable_metadata_cache
 
     def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
         """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@@ -169,8 +181,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 send_metadata=self.send_metadata_forward,
                 metadata_recv=self.metadata_recv_backward,
             )
-            self.send_metadata_forward = False
-            if self.metadata_recv_backward is None:
+            self.send_metadata_forward = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.metadata_recv_backward is None:
                 self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -190,8 +202,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 send_metadata=self.send_metadata_backward,
                 metadata_recv=self.metadata_recv_forward,
            )
-            self.send_metadata_backward = False
-            if self.metadata_recv_forward is None:
+            self.send_metadata_backward = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.metadata_recv_forward is None:
                 self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
             return input_tensor
@@ -274,32 +286,50 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj_grad[k] = v.grad
         return input_obj_grad
 
-    def forward_backward_step(
+    def run_forward_only(
         self,
         model: Module,
         data_iter: Iterable,
         criterion: Callable[..., Any],
-        optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ) -> dict:
-        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
+    ) -> Dict:
+        """
+        Runs forward-only schedule, with communication between pipeline stages.
+        """
+        assert self.forward_only
 
-        Args:
-            model (Module): Model to be trained.
-            data_iter (Iterable): Data iterator.
-            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
-            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
-            return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
-            return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
+        self.load_batch(data_iter)
 
-        Returns:
-            dict: A dict with keys: 'loss' and 'outputs'.
-        """
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage():
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
 
-        self.forward_only = not torch.is_grad_enabled()
-        if optimizer is None:
-            assert self.forward_only, "Optimizer should be passed when doing backward."
+        for _ in range(self.num_microbatches):
+            input_obj = self.recv_forward()
+            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
+            self.send_forward(output_obj)
+
+        if outputs is not None:
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+        return {"loss": accum_loss, "outputs": outputs}
+
+    def run_forward_backward(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        """
+        Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
+        """
+        assert not self.forward_only
 
         self.load_batch(data_iter)
 
@@ -309,16 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
 
         # Input, output tensors only need to be saved when doing backward passes
-        input_objs = None
-        output_objs = None
-
-        if not self.forward_only:
-            input_objs = []
-            output_objs = []
+        input_objs, output_objs = [], []
 
         accum_loss = None
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
 
         # Run warmup forward passes.
@@ -326,10 +351,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             input_obj = self.recv_forward()
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
             self.send_forward(output_obj)
-
-            if not self.forward_only:
-                input_objs.append(input_obj)
-                output_objs.append(output_obj)
+            input_objs.append(input_obj)
+            output_objs.append(output_obj)
 
         # Before running 1F1B, need to receive first forward tensor.
         # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -342,45 +365,68 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             last_iteration = i == (num_microbatches_remaining - 1)
 
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
-
-            if self.forward_only:
-                self.send_forward(output_obj)
-
-                if not last_iteration:
-                    input_obj = self.recv_forward()
-
+            output_obj_grad = self.send_forward_recv_backward(output_obj)
+            # Add input_obj and output_obj to end of list.
+            input_objs.append(input_obj)
+            output_objs.append(output_obj)
+
+            # Pop input_obj and output_obj from the start of the list for
+            # the backward pass.
+            input_obj = input_objs.pop(0)
+            output_obj = output_objs.pop(0)
+            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+
+            if last_iteration:
+                self.send_backward(input_obj_grad)
             else:
-                output_obj_grad = self.send_forward_recv_backward(output_obj)
-                # Add input_obj and output_obj to end of list.
-                input_objs.append(input_obj)
-                output_objs.append(output_obj)
-
-                # Pop output_obj and output_obj from the start of the list for
-                # the backward pass.
-                input_obj = input_objs.pop(0)
-                output_obj = output_objs.pop(0)
-                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-
-                if last_iteration:
-                    self.send_backward(input_obj_grad)
-                else:
-                    input_obj = self.send_backward_recv_forward(input_obj_grad)
+                input_obj = self.send_backward_recv_forward(input_obj_grad)
 
         # Run cooldown backward passes.
-        if not self.forward_only:
-            for i in range(num_warmup_microbatches):
-                input_obj = input_objs.pop(0)
-                output_obj = output_objs.pop(0)
+        for i in range(num_warmup_microbatches):
+            input_obj = input_objs.pop(0)
+            output_obj = output_objs.pop(0)
 
-                output_obj_grad = self.recv_backward()
-                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-                self.send_backward(input_obj_grad)
+            output_obj_grad = self.recv_backward()
+            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+            self.send_backward(input_obj_grad)
 
-        if not self.forward_only:
-            assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
             if isinstance(model, ModelWrapper):
                 model = model.unwrap()
             outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model (Module): Model to be trained.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+        Returns:
+            dict: Dictionary containing loss and outputs.
+        """
+
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
+
+        return result
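
Both schedules now end with the same thin dispatcher: grad mode decides which runner executes, and omitting the optimizer is legal only when no backward pass will run. A self-contained sketch of that contract (stub class, not the library API):

    import torch

    class TinySchedule:
        def run_forward_only(self, *args, **kwargs):
            return {"loss": None, "outputs": None}

        def run_forward_backward(self, *args, **kwargs):
            return {"loss": torch.zeros(()), "outputs": None}

        def forward_backward_step(self, model, data_iter, criterion, optimizer=None, **kwargs):
            # Inference mode (e.g., under torch.no_grad()) selects the forward-only runner.
            self.forward_only = not torch.is_grad_enabled()
            if optimizer is None:
                assert self.forward_only, "Optimizer should be passed when doing backward."
            if self.forward_only:
                return self.run_forward_only(model, data_iter, criterion, **kwargs)
            return self.run_forward_backward(model, data_iter, criterion, optimizer, **kwargs)

    sched = TinySchedule()
    with torch.no_grad():  # evaluation: no optimizer required
        out = sched.forward_backward_step(None, iter([]), None)
    assert out["loss"] is None
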
diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py
index 31c6937ee..a379b906a 100644
--- a/examples/language/bert/data.py
+++ b/examples/language/bert/data.py
@@ -88,24 +88,21 @@ class GLUEDataBuilder:
         )
 
     def val_dataloader(self):
-        # TODO: drop_last is set to True for now to avoid error when using PP
         # as the last batch may not be divisible by the number of microbatches
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(
-                self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True
-            )
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
 
     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh
index 394ff831b..fc4eacf6f 100755
--- a/examples/language/bert/test_ci.sh
+++ b/examples/language/bert/test_ci.sh
@@ -1,8 +1,17 @@
 #!/bin/bash
-set -xe
+set -x
 
 pip install -r requirements.txt
 
+FAIL_LIMIT=3
+
 for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
-    torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
+    for i in $(seq 1 $FAIL_LIMIT); do
+        torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break
+        echo "Failed $i times"
+        if [ $i -eq $FAIL_LIMIT ]; then
+            echo "Failed $FAIL_LIMIT times, exiting"
+            exit 1
+        fi
+    done
 done
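
With forward-only schedules now ceil-dividing ragged batches, the BERT example no longer needs drop_last=True on its eval loaders, and the CI script drops set -e so a flaky torchrun launch can be retried up to FAIL_LIMIT times instead of aborting the whole run. A quick illustration of the undropped tail the eval loaders now produce (plain PyTorch, illustrative sizes):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(105))
    loader = DataLoader(dataset, batch_size=32)  # drop_last defaults to False
    sizes = [len(batch[0]) for batch in loader]
    assert sizes == [32, 32, 32, 9]  # the 9-sample tail now reaches the pipeline
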