Browse Source

[pipeline]: support arbitrary batch size in forward_only mode (#5201)

* fix: remove drop last in val & test dataloader

* feat: add run_forward_only, support arbitrary bs

* chore: modify ci script
pull/5190/head
Wenhao Chen 11 months ago committed by Xuanlei Zhao
parent
commit
931d0e0731
  1. 221
      colossalai/pipeline/schedule/interleaved_pp.py
  2. 206
      colossalai/pipeline/schedule/one_f_one_b.py
  3. 11
      examples/language/bert/data.py
  4. 13
      examples/language/bert/test_ci.sh

221
colossalai/pipeline/schedule/interleaved_pp.py

@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
@ -22,6 +22,7 @@ class InterleavedSchedule(PipelineSchedule):
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
@ -39,6 +40,7 @@ class InterleavedSchedule(PipelineSchedule):
self.microbatch_offset: List[int]
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
@ -54,30 +56,33 @@ class InterleavedSchedule(PipelineSchedule):
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
if self.num_microbatch is not None:
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
self.microbatch_size = self.batch_size // self.num_microbatch
elif self.microbatch_size is not None:
if self.num_microbatch is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
else:
raise ValueError("Either num_microbatch or microbatch_size should be provided")
assert (
self.num_microbatch % self.num_model_chunks == 0
), "Number of microbatch should be an integer multiple of number of model chunks"
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch
assert (
self.num_microbatch % self.stage_manager.num_stages == 0
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
self.last_batch_size = self.batch_size
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
@ -88,6 +93,7 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
@ -122,7 +128,7 @@ class InterleavedSchedule(PipelineSchedule):
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.metadata_recv_forward is None:
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
@ -141,7 +147,7 @@ class InterleavedSchedule(PipelineSchedule):
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.metadata_recv_backward is None:
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
@ -158,7 +164,7 @@ class InterleavedSchedule(PipelineSchedule):
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
self.send_metadata_forward = not self.enable_metadata_cache
def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
@ -172,7 +178,7 @@ class InterleavedSchedule(PipelineSchedule):
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
self.send_metadata_backward = not self.enable_metadata_cache
def send_forward_recv_backward(
self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
@ -185,8 +191,8 @@ class InterleavedSchedule(PipelineSchedule):
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.send_metadata_forward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
@ -202,8 +208,8 @@ class InterleavedSchedule(PipelineSchedule):
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.send_metadata_backward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
@ -297,66 +303,74 @@ class InterleavedSchedule(PipelineSchedule):
input_obj_grad[k] = v.grad
return input_obj_grad
def forward_backward_step(
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""Runs interleaved schedule, with communication between pipeline stages.
) -> Dict:
assert self.forward_only
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
self.load_batch(data_iter)
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
# Run warmup forward passes.
for i in range(self.num_microbatch * self.num_model_chunks):
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
Runs interleaved schedule, with communication between pipeline stages.
"""
assert not self.forward_only
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
if self.forward_only:
num_warmup_microbatch = num_microbatch
else:
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not self.forward_only:
input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)]
input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.scalar_tensor(0, device=get_current_device())
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not self.forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatch_remaining > 0:
@ -369,47 +383,72 @@ class InterleavedSchedule(PipelineSchedule):
last_iteration = i == num_microbatch_remaining - 1
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if self.forward_only:
if not last_iteration:
input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
else:
self.send_forward(model_chunk_id, output_obj)
else:
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if not last_iteration:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
if not last_iteration:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
# Run cooldown backward passes.
if not self.forward_only:
for i in range(num_microbatch_remaining, num_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
for i in range(num_microbatch_remaining, num_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
return result

206
colossalai/pipeline/schedule/one_f_one_b.py

@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None:
"""1F1B pipeline schedule.
@ -50,9 +51,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
@ -69,29 +70,40 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = 0
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = 0
if not self._use_microbatch_size:
assert (
self.batch_size % self.num_microbatches == 0
), "Batch size should divided by the number of microbatches"
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by # microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
else:
if self.num_microbatches is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatches
if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
self.last_batch_size = self.batch_size
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
@ -108,7 +120,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.metadata_recv_forward is None:
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
@ -125,7 +137,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.metadata_recv_backward is None:
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
@ -140,7 +152,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
self.send_metadata_forward = not self.enable_metadata_cache
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
@ -152,7 +164,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
self.send_metadata_backward = not self.enable_metadata_cache
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@ -169,8 +181,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.send_metadata_forward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
@ -190,8 +202,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.send_metadata_backward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
@ -274,32 +286,50 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad[k] = v.grad
return input_obj_grad
def forward_backward_step(
def run_forward_only(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
) -> Dict:
"""
Runs forward only schedule, with communication between pipeline stages.
"""
assert self.forward_only
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
self.load_batch(data_iter)
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.scalar_tensor(0, device=get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
for _ in range(self.num_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
"""
assert not self.forward_only
self.load_batch(data_iter)
@ -309,16 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not self.forward_only:
input_objs = []
output_objs = []
input_objs, output_objs = [], []
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.scalar_tensor(0, device=get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
# Run warmup forward passes.
@ -326,10 +351,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if not self.forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@ -342,45 +365,68 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if self.forward_only:
self.send_forward(output_obj)
if not last_iteration:
input_obj = self.recv_forward()
output_obj_grad = self.send_forward_recv_backward(output_obj)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
self.send_backward(input_obj_grad)
else:
output_obj_grad = self.send_forward_recv_backward(output_obj)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(input_obj_grad)
input_obj = self.send_backward_recv_forward(input_obj_grad)
# Run cooldown backward passes.
if not self.forward_only:
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
Returns:
dict: Dictionary containing loss and outputs.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
return result

11
examples/language/bert/data.py

@ -88,24 +88,21 @@ class GLUEDataBuilder:
)
def val_dataloader(self):
# TODO: drop_last is set to True for now to avoid error when using PP
# as the last batch may not be divisible by the number of microbatches
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(
self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True
)
return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]
def test_dataloader(self):
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]

13
examples/language/bert/test_ci.sh

@ -1,8 +1,17 @@
#!/bin/bash
set -xe
set -x
pip install -r requirements.txt
FAIL_LIMIT=3
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
for i in $(seq 1 $FAIL_LIMIT); do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break
echo "Failed $i times"
if [ $i -eq $FAIL_LIMIT ]; then
echo "Failed $FAIL_LIMIT times, exiting"
exit 1
fi
done
done

Loading…
Cancel
Save