mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] 1f1b schedule receive microbatch size (#4589)
parent 38ccb8b1a3
commit 508ca36fe3
@@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT. Defaults to False.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
         initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
         min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
         growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
@@ -278,6 +281,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                  enable_jit_fused: bool = False,
                  enable_sequence_parallelism: bool = False,
                  num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None,
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
                  growth_factor: float = 2,
@@ -324,7 +328,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
             assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
             self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
+            self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+                                                          num_microbatches=num_microbatches,
+                                                          microbatch_size=microbatch_size)
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
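For context, a minimal usage sketch of what this plugin-side change enables, per the new docstring: pipeline microbatching can now be configured by microbatch size rather than by microbatch count. This is not code from the commit; the tp_size/pp_size arguments and the Booster wrapper are assumed from the surrounding ColossalAI API of the time, and model/optimizer/dataloader setup is omitted.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumed sketch: split each pipeline batch into microbatches of 1 sample
# instead of specifying how many microbatches there should be.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    microbatch_size=1,    # new in this commit; num_microbatches may be given instead
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)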
@@ -17,14 +17,26 @@ from .base import PipelineSchedule
 
 class OneForwardOneBackwardSchedule(PipelineSchedule):
 
-    def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
+    def __init__(self,
+                 stage_manager: PipelineStageManager,
+                 num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None) -> None:
+        """1F1B pipeline schedule.
+
+        Args:
+            stage_manager (PipelineStageManager): Pipeline stage manager
+            num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
+            microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
+        """
         super().__init__(stage_manager)
+        assert num_microbatches is not None or microbatch_size is not None, \
+            "Either num_microbatches or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
+        self.microbatch_size = microbatch_size
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self.microbatch_size: Optional[int] = None
 
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
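With the new signature, the stage manager is the first positional argument and microbatching is specified through at least one of the two keyword arguments (if both are given, num_microbatches takes precedence per the docstring). A minimal sketch of the call forms, assuming the import path matches the package layout and stage_manager is an existing PipelineStageManager:

from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule

# a) fix the number of microbatches; microbatch_size is derived later in load_batch()
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=4)

# b) fix the microbatch size; num_microbatches is derived later in load_batch()
schedule = OneForwardOneBackwardSchedule(stage_manager, microbatch_size=2)

# c) providing neither trips the new assertion in __init__
# OneForwardOneBackwardSchedule(stage_manager)  # AssertionError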
@@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        assert self.batch_size % self.num_microbatches == 0, \
-            "Batch size should be divisible by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        if self.num_microbatches is not None:
+            assert self.batch_size % self.num_microbatches == 0, \
+                "Batch size should be divisible by the number of microbatches"
+            self.microbatch_size = self.batch_size // self.num_microbatches
+        else:
+            assert self.batch_size % self.microbatch_size == 0, \
+                "Batch size should be divisible by the microbatch size"
+            self.num_microbatches = self.batch_size // self.microbatch_size
 
     def load_micro_batch(self) -> Any:
         """Load a micro batch from the current batch.
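To make the branch above concrete, here is a standalone re-statement of the derivation as a hypothetical helper (not code from this commit): with a global batch of 32 samples, microbatch_size=4 yields 8 microbatches, exactly as num_microbatches=8 would.

from typing import Optional, Tuple

def resolve_microbatching(batch_size: int,
                          num_microbatches: Optional[int] = None,
                          microbatch_size: Optional[int] = None) -> Tuple[int, int]:
    # Mirror of the load_batch logic: derive whichever quantity was not provided.
    if num_microbatches is not None:
        assert batch_size % num_microbatches == 0, "batch size must be divisible by num_microbatches"
        return num_microbatches, batch_size // num_microbatches
    assert microbatch_size is not None, "either num_microbatches or microbatch_size must be given"
    assert batch_size % microbatch_size == 0, "batch size must be divisible by microbatch_size"
    return batch_size // microbatch_size, microbatch_size

assert resolve_microbatching(32, num_microbatches=8) == (8, 4)
assert resolve_microbatching(32, microbatch_size=4) == (8, 4)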
@@ -61,7 +61,7 @@ def examine_pp():
     DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
     pg_mesh = ProcessGroupMesh(1, world_size, 1)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
 
     for idx, (_, sub_model) in enumerate(pp_model.named_children()):
         if idx % (world_size) == local_rank: