|
|
|
@ -17,14 +17,26 @@ from .base import PipelineSchedule
|
|
|
|
|
|
|
|
|
|
class OneForwardOneBackwardSchedule(PipelineSchedule): |
|
|
|
|
|
|
|
|
|
def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None: |
|
|
|
|
def __init__(self, |
|
|
|
|
stage_manager: PipelineStageManager, |
|
|
|
|
num_microbatches: Optional[int] = None, |
|
|
|
|
microbatch_size: Optional[int] = None) -> None: |
|
|
|
|
"""1F1B pipeline schedule. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
stage_manager (PipelineStageManager): Pipeline stage manager |
|
|
|
|
num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. |
|
|
|
|
microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
super().__init__(stage_manager) |
|
|
|
|
assert num_microbatches is not None or microbatch_size is not None, \ |
|
|
|
|
"Either num_microbatches or microbatch_size should be provided" |
|
|
|
|
self.comm = PipelineP2PCommunication(stage_manager) |
|
|
|
|
self.num_microbatches = num_microbatches |
|
|
|
|
self.microbatch_size = microbatch_size |
|
|
|
|
self.batch: Optional[Any] = None |
|
|
|
|
self.batch_size: Optional[int] = None |
|
|
|
|
self.microbatch_offset: Optional[int] = None |
|
|
|
|
self.microbatch_size: Optional[int] = None |
|
|
|
|
|
|
|
|
|
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: |
|
|
|
|
"""Load a batch from data iterator. |
|
|
|
@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|
|
|
|
self.batch = batch |
|
|
|
|
self.batch_size = get_batch_size(batch) |
|
|
|
|
self.microbatch_offset = 0 |
|
|
|
|
assert self.batch_size % self.num_microbatches == 0, \ |
|
|
|
|
"Batch size should divided by the number of microbatches" |
|
|
|
|
self.microbatch_size = self.batch_size // self.num_microbatches |
|
|
|
|
if self.num_microbatches is not None: |
|
|
|
|
assert self.batch_size % self.num_microbatches == 0, \ |
|
|
|
|
"Batch size should divided by the number of microbatches" |
|
|
|
|
self.microbatch_size = self.batch_size // self.num_microbatches |
|
|
|
|
else: |
|
|
|
|
assert self.batch_size % self.microbatch_size == 0, \ |
|
|
|
|
"Batch size should divided by the microbatch size" |
|
|
|
|
self.num_microbatches = self.batch_size // self.microbatch_size |
|
|
|
|
|
|
|
|
|
def load_micro_batch(self) -> Any: |
|
|
|
|
"""Load a micro batch from the current batch. |
|
|
|
|