mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] 1f1b schedule receive microbatch size (#4589)
parent
38ccb8b1a3
commit
508ca36fe3
|
@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
|
||||
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
|
||||
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
|
||||
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
|
||||
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
|
||||
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
|
||||
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
|
||||
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
|
||||
|
@ -278,6 +281,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
|
@ -324,7 +328,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
|
||||
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
|
||||
self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
|
||||
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
|
||||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size)
|
||||
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
|
||||
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
|
||||
|
|
|
@ -17,14 +17,26 @@ from .base import PipelineSchedule
|
|||
|
||||
class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
|
||||
def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
|
||||
def __init__(self,
|
||||
stage_manager: PipelineStageManager,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None) -> None:
|
||||
"""1F1B pipeline schedule.
|
||||
|
||||
Args:
|
||||
stage_manager (PipelineStageManager): Pipeline stage manager
|
||||
num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
|
||||
microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
|
||||
"""
|
||||
super().__init__(stage_manager)
|
||||
assert num_microbatches is not None or microbatch_size is not None, \
|
||||
"Either num_microbatches or microbatch_size should be provided"
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatches = num_microbatches
|
||||
self.microbatch_size = microbatch_size
|
||||
self.batch: Optional[Any] = None
|
||||
self.batch_size: Optional[int] = None
|
||||
self.microbatch_offset: Optional[int] = None
|
||||
self.microbatch_size: Optional[int] = None
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
self.microbatch_offset = 0
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
if self.num_microbatches is not None:
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
else:
|
||||
assert self.batch_size % self.microbatch_size == 0, \
|
||||
"Batch size should divided by the microbatch size"
|
||||
self.num_microbatches = self.batch_size // self.microbatch_size
|
||||
|
||||
def load_micro_batch(self) -> Any:
|
||||
"""Load a micro batch from the current batch.
|
||||
|
|
|
@ -61,7 +61,7 @@ def examine_pp():
|
|||
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
|
||||
pg_mesh = ProcessGroupMesh(1, world_size, 1)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
|
||||
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
|
||||
|
||||
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
|
||||
if idx % (world_size) == local_rank:
|
||||
|
|
Loading…
Reference in New Issue