[pipeline]: support arbitrary batch size in forward_only mode (#5201)

* fix: remove drop last in val & test dataloader

* feat: add run_forward_only, support arbitrary bs

* chore: modify ci script
Wenhao Chen 2024-01-02 23:41:12 +08:00 committed by Xuanlei Zhao
parent 1810b9100f
commit 931d0e0731
4 changed files with 293 additions and 202 deletions
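
What the change enables in practice: both schedules below now recompute the number of microbatches per batch whenever gradients are disabled, so evaluation can feed a trailing, smaller batch instead of dropping it. A minimal eval-loop sketch, assuming a booster built with HybridParallelPlugin as in the BERT example; `booster`, `model`, `_criterion`, and `val_dataloader` are illustrative names, not part of this diff:

    # Illustrative only: evaluation with a final batch that is not divisible
    # by the microbatch count, relying on the forward-only path added below.
    import torch

    model.eval()
    with torch.no_grad():  # grad disabled -> schedules run in forward_only mode
        data_iter = iter(val_dataloader)  # no drop_last: the last batch may be smaller
        for _ in range(len(val_dataloader)):
            out = booster.execute_pipeline(
                data_iter, model, _criterion, return_loss=True, return_outputs=True
            )
            # out["loss"] / out["outputs"] are populated only on the last pipeline stage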

View File

@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import torch
 import torch.cuda
@@ -22,6 +22,7 @@ class InterleavedSchedule(PipelineSchedule):
         num_model_chunks: int,
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
+        enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__(stage_manager)
         assert (
@@ -39,6 +40,7 @@ class InterleavedSchedule(PipelineSchedule):
         self.microbatch_offset: List[int]
 
         # P2PMeta cache
+        self.enable_metadata_cache = enable_metadata_cache
         self.send_metadata_forward = True
         self.send_metadata_backward = True
         self.metadata_recv_forward = None
@@ -54,30 +56,33 @@ class InterleavedSchedule(PipelineSchedule):
         batch = next(data_iter)
         if device is not None:
             batch = tree_map(partial(to_device, device=device), batch)
+
+        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
         self.batch = batch
         self.batch_size = get_batch_size(batch)
-        if self.last_batch_size is None:
-            self.last_batch_size = self.batch_size
-        else:
-            assert self.forward_only or self.last_batch_size == self.batch_size
-            # TODO: support arbitrary batch size when forward_only=True
-        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        if self.num_microbatch is not None:
+
+        if self.microbatch_size is None:
             assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
             self.microbatch_size = self.batch_size // self.num_microbatch
-        elif self.microbatch_size is not None:
+        if self.num_microbatch is None:
             assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
             self.num_microbatch = self.batch_size // self.microbatch_size
-        else:
-            raise ValueError("Either num_microbatch or microbatch_size should be provided")
-
-        assert (
-            self.num_microbatch % self.num_model_chunks == 0
-        ), "Number of microbatch should be an integer multiple of number of model chunks"
-        assert (
-            self.num_microbatch % self.stage_manager.num_stages == 0
-        ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
+
+        if not self.forward_only:
+            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
+            assert self.batch_size == self.microbatch_size * self.num_microbatch
+
+        if self.forward_only:
+            self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
+            # NOTE: disable metadata cache when batch size changes (not valid anymore)
+            if self.batch_size != self.last_batch_size:
+                self.enable_metadata_cache = False
+                self.send_metadata_forward = True
+                self.send_metadata_backward = True
+                self.metadata_recv_forward = None
+                self.metadata_recv_backward = None
+
+        self.last_batch_size = self.batch_size
 
     def load_micro_batch(self, model_chunk_id: int) -> Any:
         """Load a micro batch from the current batch.
@@ -88,6 +93,7 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: Micro batch.
         """
+        assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
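
The arithmetic above is the core of the change: in forward-only mode the microbatch count is now a ceiling division of the batch size, the strict divisibility check only applies when training, and load_micro_batch merely asserts that the offset has not run past the batch. A standalone sketch with made-up numbers (assuming, as in `get_micro_batch`, that slicing along the batch dimension clips at the end):

    # Illustrative numbers, not from the diff.
    batch_size, microbatch_size = 10, 4
    num_microbatch = (batch_size - 1) // microbatch_size + 1  # ceil(10 / 4) = 3

    offset = 0
    for _ in range(num_microbatch):
        assert offset <= batch_size, "Microbatches exhausted"
        micro = list(range(batch_size))[offset : offset + microbatch_size]  # last slice is shorter
        print(len(micro))  # 4, 4, 2 -- the trailing microbatch is allowed to be partial
        offset += microbatch_size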
@@ -122,7 +128,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
                 input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-                if self.metadata_recv_forward is None:
+                if self.enable_metadata_cache and self.metadata_recv_forward is None:
                     self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
                 return input_tensor
@@ -141,7 +147,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
                 output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-                if self.metadata_recv_backward is None:
+                if self.enable_metadata_cache and self.metadata_recv_backward is None:
                     self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
                 return output_tensor_grad
@@ -158,7 +164,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
                 self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-                self.send_metadata_forward = False
+                self.send_metadata_forward = not self.enable_metadata_cache
 
     def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
@@ -172,7 +178,7 @@ class InterleavedSchedule(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
                 self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-                self.send_metadata_backward = False
+                self.send_metadata_backward = not self.enable_metadata_cache
 
     def send_forward_recv_backward(
         self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
@@ -185,8 +191,8 @@ class InterleavedSchedule(PipelineSchedule):
                     send_metadata=self.send_metadata_forward,
                     metadata_recv=self.metadata_recv_backward,
                 )
-                self.send_metadata_forward = False
-                if self.metadata_recv_backward is None:
+                self.send_metadata_forward = not self.enable_metadata_cache
+                if self.enable_metadata_cache and self.metadata_recv_backward is None:
                     self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
                 return output_tensor_grad
@@ -202,8 +208,8 @@ class InterleavedSchedule(PipelineSchedule):
                     send_metadata=self.send_metadata_backward,
                     metadata_recv=self.metadata_recv_forward,
                 )
-                self.send_metadata_backward = False
-                if self.metadata_recv_forward is None:
+                self.send_metadata_backward = not self.enable_metadata_cache
+                if self.enable_metadata_cache and self.metadata_recv_forward is None:
                     self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
                 return input_tensor
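
The pattern repeated through the send/recv hunks above: with `enable_metadata_cache` on, tensor metadata is transmitted once and the received metadata is memoized; with it off (as happens when the batch size changes in forward-only mode), every P2P call keeps exchanging metadata. A distilled, illustrative sketch of that flag logic, outside the real classes and with a stand-in `comm` object:

    # Distilled, illustrative version of the caching flags used above.
    class MetaCacheDemo:
        def __init__(self, enable_metadata_cache: bool = True):
            self.enable_metadata_cache = enable_metadata_cache
            self.send_metadata_forward = True   # send full metadata on the first send
            self.metadata_recv_forward = None   # memoized metadata for later receives

        def send_forward(self, comm, output_object, next_rank):
            comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
            # caching on: False after the first send; caching off: stays True forever
            self.send_metadata_forward = not self.enable_metadata_cache

        def recv_forward(self, comm, prev_rank, create_fast_send_metadata):
            input_tensor = comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
            if self.enable_metadata_cache and self.metadata_recv_forward is None:
                self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
            return input_tensor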
@@ -297,7 +303,36 @@ class InterleavedSchedule(PipelineSchedule):
                     input_obj_grad[k] = v.grad
         return input_obj_grad
 
-    def forward_backward_step(
+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        self.load_batch(data_iter)
+
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
+
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+
+        # Run warmup forward passes.
+        for i in range(self.num_microbatch * self.num_model_chunks):
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+            input_obj = self.recv_forward(model_chunk_id)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+            self.send_forward(model_chunk_id, output_obj)
+
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
+    def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
         data_iter: Iterable,
@@ -305,41 +340,21 @@ class InterleavedSchedule(PipelineSchedule):
         optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ) -> dict:
-        """Runs interleaved schedule, with communication between pipeline stages.
-
-        Args:
-            model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
-            data_iter (Iterable): Data iterator.
-            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
-            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
-            return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
-            return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
-
-        Returns:
-            dict: A dict with keys: 'loss' and 'outputs'.
+    ) -> Dict:
         """
-        self.forward_only = not torch.is_grad_enabled()
-        if optimizer is None:
-            assert self.forward_only, "Optimizer should be passed when doing backward."
-
+        Runs interleaved schedule, with communication between pipeline stages.
+        """
+        assert not self.forward_only
         self.load_batch(data_iter)
 
         num_microbatch = self.num_microbatch * self.num_model_chunks
-        if self.forward_only:
-            num_warmup_microbatch = num_microbatch
-        else:
-            num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
-            num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
-            num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
+        num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+        num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
+        num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
 
         num_microbatch_remaining = num_microbatch - num_warmup_microbatch
 
         # Input, output tensors only need to be saved when doing backward passes
-        input_objs = None
-        output_objs = None
-
-        if not self.forward_only:
-            input_objs = [[] for _ in range(self.num_model_chunks)]
-            output_objs = [[] for _ in range(self.num_model_chunks)]
+        input_objs = [[] for _ in range(self.num_model_chunks)]
+        output_objs = [[] for _ in range(self.num_model_chunks)]
@@ -347,14 +362,13 @@ class InterleavedSchedule(PipelineSchedule):
         accum_loss = None
         if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
 
         # Run warmup forward passes.
         for i in range(num_warmup_microbatch):
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
             input_obj = self.recv_forward(model_chunk_id)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            if not self.forward_only:
-                input_objs[model_chunk_id].append(input_obj)
-                output_objs[model_chunk_id].append(output_obj)
+            input_objs[model_chunk_id].append(input_obj)
+            output_objs[model_chunk_id].append(output_obj)
             self.send_forward(model_chunk_id, output_obj)
@@ -369,13 +383,6 @@ class InterleavedSchedule(PipelineSchedule):
             last_iteration = i == num_microbatch_remaining - 1
 
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            if self.forward_only:
-                if not last_iteration:
-                    input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
-                else:
-                    self.send_forward(model_chunk_id, output_obj)
-
-            else:
-                self.send_forward(model_chunk_id, output_obj)
+            self.send_forward(model_chunk_id, output_obj)
 
             # Add input_obj and output_obj to end of list.
             input_objs[model_chunk_id].append(input_obj)
@@ -398,7 +405,6 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj = self.recv_forward(model_chunk_id)
 
         # Run cooldown backward passes.
-        if not self.forward_only:
         for i in range(num_microbatch_remaining, num_microbatch):
             model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
             input_obj = input_objs[model_chunk_id].pop(0)
@@ -407,9 +413,42 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                 self.send_backward(model_chunk_id, input_obj_grad)
 
-        if not self.forward_only:
-            assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
+
+        Returns:
+            dict: A dict with keys: 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(
+                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
+            )
+
+        return result

View File

@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import torch
 import torch.cuda
@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         stage_manager: PipelineStageManager,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
+        enable_metadata_cache: bool = True,
     ) -> None:
         """1F1B pipeline schedule.
@@ -50,9 +51,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch_size: Optional[int] = None
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self._use_microbatch_size = num_microbatches is None
 
         # P2PMeta cache
+        self.enable_metadata_cache = enable_metadata_cache
         self.send_metadata_forward = True
         self.send_metadata_backward = True
         self.metadata_recv_forward = None
@@ -69,29 +70,40 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if device is not None:
             batch = tree_map(partial(to_device, device=device), batch)
 
+        self.microbatch_offset = 0
         self.batch = batch
         self.batch_size = get_batch_size(batch)
-        if self.last_batch_size is None:
-            self.last_batch_size = self.batch_size
-        else:
-            assert self.forward_only or self.last_batch_size == self.batch_size
-            # TODO: support arbitrary batch size when forward_only=True
-        self.microbatch_offset = 0
-        if not self._use_microbatch_size:
-            assert (
-                self.batch_size % self.num_microbatches == 0
-            ), "Batch size should divided by the number of microbatches"
+
+        if self.microbatch_size is None:
+            assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by # microbatches"
             self.microbatch_size = self.batch_size // self.num_microbatches
-        else:
+        if self.num_microbatches is None:
             assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
             self.num_microbatches = self.batch_size // self.microbatch_size
+
+        if not self.forward_only:
+            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
+            assert self.batch_size == self.microbatch_size * self.num_microbatches
+
+        if self.forward_only:
+            self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
+            # NOTE: disable metadata cache when batch size changes (not valid anymore)
+            if self.batch_size != self.last_batch_size:
+                self.enable_metadata_cache = False
+                self.send_metadata_forward = True
+                self.send_metadata_backward = True
+                self.metadata_recv_forward = None
+                self.metadata_recv_backward = None
+
+        self.last_batch_size = self.batch_size
 
     def load_micro_batch(self) -> Any:
         """Load a micro batch from the current batch.
 
         Returns:
             Any: Micro batch.
         """
+        assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
@@ -108,7 +120,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
-            if self.metadata_recv_forward is None:
+            if self.enable_metadata_cache and self.metadata_recv_forward is None:
                 self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
             return input_tensor
@@ -125,7 +137,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
-            if self.metadata_recv_backward is None:
+            if self.enable_metadata_cache and self.metadata_recv_backward is None:
                 self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -140,7 +152,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_last_stage():
             self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
-            self.send_metadata_forward = False
+            self.send_metadata_forward = not self.enable_metadata_cache
 
     def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
@@ -152,7 +164,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_first_stage():
             self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
-            self.send_metadata_backward = False
+            self.send_metadata_backward = not self.enable_metadata_cache
 
     def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
         """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@@ -169,8 +181,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 send_metadata=self.send_metadata_forward,
                 metadata_recv=self.metadata_recv_backward,
             )
-            self.send_metadata_forward = False
-            if self.metadata_recv_backward is None:
+            self.send_metadata_forward = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.metadata_recv_backward is None:
                 self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
             return output_tensor_grad
@@ -190,8 +202,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 send_metadata=self.send_metadata_backward,
                 metadata_recv=self.metadata_recv_forward,
             )
-            self.send_metadata_backward = False
-            if self.metadata_recv_forward is None:
+            self.send_metadata_backward = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.metadata_recv_forward is None:
                 self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
             return input_tensor
@@ -274,7 +286,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                     input_obj_grad[k] = v.grad
         return input_obj_grad
 
-    def forward_backward_step(
+    def run_forward_only(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        """
+        Runs forward only schedule, with communication between pipeline stages.
+        """
+        assert self.forward_only
+
+        self.load_batch(data_iter)
+
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage():
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+
+        for _ in range(self.num_microbatches):
+            input_obj = self.recv_forward()
+            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
+            self.send_forward(output_obj)
+
+        if outputs is not None:
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+        return {"loss": accum_loss, "outputs": outputs}
+
+    def run_forward_backward(
         self,
         model: Module,
         data_iter: Iterable,
@@ -282,24 +325,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ) -> dict:
-        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
-
-        Args:
-            model (Module): Model to be trained.
-            data_iter (Iterable): Data iterator.
-            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
-            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
-            return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
-            return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
-
-        Returns:
-            dict: A dict with keys: 'loss' and 'outputs'.
+    ) -> Dict:
         """
-        self.forward_only = not torch.is_grad_enabled()
-        if optimizer is None:
-            assert self.forward_only, "Optimizer should be passed when doing backward."
+        Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
+        """
+        assert not self.forward_only
 
         self.load_batch(data_iter)
@@ -309,16 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
 
         # Input, output tensors only need to be saved when doing backward passes
-        input_objs = None
-        output_objs = None
-
-        if not self.forward_only:
-            input_objs = []
-            output_objs = []
+        input_objs, output_objs = [], []
 
         accum_loss = None
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
 
         # Run warmup forward passes.
@@ -326,8 +351,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             input_obj = self.recv_forward()
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
             self.send_forward(output_obj)
-
-            if not self.forward_only:
             input_objs.append(input_obj)
             output_objs.append(output_obj)
@@ -342,14 +365,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             last_iteration = i == (num_microbatches_remaining - 1)
 
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
-            if self.forward_only:
-                self.send_forward(output_obj)
-
-                if not last_iteration:
-                    input_obj = self.recv_forward()
-
-            else:
-                output_obj_grad = self.send_forward_recv_backward(output_obj)
+            output_obj_grad = self.send_forward_recv_backward(output_obj)
 
             # Add input_obj and output_obj to end of list.
             input_objs.append(input_obj)
@@ -367,7 +382,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj = self.send_backward_recv_forward(input_obj_grad)
 
         # Run cooldown backward passes.
-        if not self.forward_only:
         for i in range(num_warmup_microbatches):
             input_obj = input_objs.pop(0)
             output_obj = output_objs.pop(0)
@@ -376,7 +390,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
             self.send_backward(input_obj_grad)
 
-        if not self.forward_only:
         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
@@ -384,3 +397,36 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 model = model.unwrap()
             outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model (Module): Model to be trained.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
+
+        Returns:
+            dict: Dictionary containing loss and outputs.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
+
+        return result
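
In both schedules `forward_backward_step` is now a thin dispatcher: the caller's autograd mode decides which path runs, and an optimizer is only required on the training path. A reduced, illustrative sketch of how that dispatch is driven from calling code (`schedule`, `loader`, `criterion`, `optimizer` are placeholder names):

    # Illustrative only: the caller's grad mode selects the execution path.
    import torch

    with torch.no_grad():
        # torch.is_grad_enabled() is False here -> run_forward_only, optimizer may be None
        schedule.forward_backward_step(model, iter(loader), criterion, return_loss=True)

    # grad enabled -> run_forward_backward, optimizer must be provided
    schedule.forward_backward_step(model, iter(loader), criterion, optimizer, return_loss=True)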

View File

@@ -88,24 +88,21 @@ class GLUEDataBuilder:
         )
 
     def val_dataloader(self):
-        # TODO: drop_last is set to True for now to avoid error when using PP
-        # as the last batch may not be divisible by the number of microbatches
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(
-                self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True
-            )
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
 
     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
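
Dropping `drop_last` matters for evaluation quality: with `drop_last=True`, up to `eval_batch_size - 1` validation samples were silently excluded every epoch, which is exactly what the forward-only schedules above make unnecessary. A quick illustration with made-up numbers:

    # Made-up numbers, for illustration only.
    num_val_samples, eval_batch_size = 1043, 32
    full_batches, remainder = divmod(num_val_samples, eval_batch_size)  # 32 full batches, 19 left over
    # drop_last=True : 32 * 32 = 1024 samples scored, 19 never evaluated
    # drop_last off  : 33 batches; the 19-sample batch goes through the forward-only path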

View File

@@ -1,8 +1,17 @@
 #!/bin/bash
-set -xe
+set -x
 
 pip install -r requirements.txt
 
+FAIL_LIMIT=3
+
 for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
-    torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
+    for i in $(seq 1 $FAIL_LIMIT); do
+        torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break
+        echo "Failed $i times"
+        if [ $i -eq $FAIL_LIMIT ]; then
+            echo "Failed $FAIL_LIMIT times, exiting"
+            exit 1
+        fi
+    done
 done