suppport interleaved pp

pull/182/head
zhanglei 2023-08-16 12:02:59 +08:00
parent 7b4933de0d
commit 8cdd1abb35
2 changed files with 28 additions and 13 deletions

View File

@ -125,7 +125,7 @@ model = dict(
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_chunks=2, # if num_chunks > 1, interleaved pipeline scheduler is used.
sequence_parallel=False,
num_experts=4,
moe_use_residual=True,
@ -144,7 +144,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
# zero1=8,
pipeline=dict(size=4, interleaved_overlap=False),
pipeline=dict(size=4, interleaved_overlap=True),
tensor=dict(size=2),
)

View File

@ -698,6 +698,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._input_objs = [[] for _ in range(num_chunks)]
self._output_objs = [[] for _ in range(num_chunks)]
self._output_obj_grads = [[] for _ in range(num_chunks)]
self._moe_losses = [[] for _ in range(num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
self._output_obj_shapes = [None for _ in range(num_chunks)]
@ -709,6 +710,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._input_objs = [[] for _ in range(self._num_chunks)]
self._output_objs = [[] for _ in range(self._num_chunks)]
self._output_obj_grads = [[] for _ in range(self._num_chunks)]
self._moe_losses = [[] for _ in range(self._num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
self._output_obj_shapes = [None for _ in range(self._num_chunks)]
@ -730,7 +732,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return move_to_device(micro_batch_data)
def _forward_step(self, engine, chunk_id):
def _forward_step(self, engine, chunk_id, moe_loss_coeff:float=1.0):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -752,7 +754,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model[chunk_id], data)
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@ -772,7 +774,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
self._output_objs[chunk_id].append(output_obj)
self._moe_losses[chunk_id].append(moe_loss)
return output_obj
@ -798,8 +804,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
input_obj = self._input_objs[chunk_id].pop(0)
output_obj = self._output_objs[chunk_id].pop(0)
output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
moe_loss = self._moe_losses[chunk_id].pop(0)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss)
return input_obj_grad
@ -831,6 +838,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
receive_extra_backward: bool = False,
forward_only: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the warm-up loop and prepare data for the 1F1B stage.
@ -868,12 +876,13 @@ class InterleavedPipelineScheduler(PipelineScheduler):
for k in range(num_warmup_microsteps):
chunk_id = self._get_chunk_by_microbatch(k)
output_obj = self._forward_step(engine, chunk_id)
output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff)
if forward_only:
# when forward-only, no need to save tensors for a backward pass
self._input_objs[chunk_id].pop()
self._output_objs[chunk_id].pop()
self._moe_losses[chunk_id].pop()
if not gpc.is_pipeline_last_stage():
if isinstance(output_obj, torch.Tensor):
@ -949,6 +958,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the 1F1B loop with overlap.
@ -978,7 +988,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
# 1. Forward pass.
output_obj = self._forward_step(engine, forward_chunk_id)
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
# 2. Check if the backward input is ready.
if backward_async_communicator is not None:
@ -1063,6 +1073,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the 1F1B loop without overlap.
@ -1084,7 +1095,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# Forward pass.
forward_microstep_id = k + num_warmup_microsteps
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
output_obj = self._forward_step(engine, forward_chunk_id)
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
# Backward pass.
backward_microstep_id = k
@ -1189,7 +1200,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
)
)
def _forward_only_step(self, engine: Engine):
def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
num_microsteps = self.num_microbatches * self._num_chunks
num_warmup_microsteps = num_microsteps
@ -1199,9 +1210,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps,
receive_extra_backward=False,
forward_only=True,
moe_loss_coeff=moe_loss_coeff,
)
def _forward_backward_step(self, engine: Engine):
def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
# Compute number of warmup and remaining microbatches.
all_warmup_microsteps = False
num_microsteps = self.num_microbatches * self._num_chunks
@ -1235,6 +1247,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_microsteps,
num_warmup_steps,
receive_extra_backward=receive_extra_backward,
moe_loss_coeff=moe_loss_coeff,
)
# 2. 1F1B
@ -1243,12 +1256,14 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_steps,
num_1f1b_micropairs=num_1f1b_micropairs,
all_warmup_microsteps=all_warmup_microsteps,
moe_loss_coeff=moe_loss_coeff,
)
# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
return_output_label=True, moe_loss_coeff:float=1.0):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
@ -1276,9 +1291,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._return_tensors = []
if forward_only:
self._forward_only_step(engine)
self._forward_only_step(engine, moe_loss_coeff)
else:
self._forward_backward_step(engine)
self._forward_backward_step(engine, moe_loss_coeff)
if return_output_label and len(self._return_tensors) > 0:
output, label = pack_return_tensors(self._return_tensors)