support interleaved pp (pipeline parallelism)

pull/182/head
zhanglei 2023-08-16 12:02:59 +08:00
parent 7b4933de0d
commit 8cdd1abb35
2 changed files with 28 additions and 13 deletions

View File

@ -125,7 +125,7 @@ model = dict(
norm_type="rmsnorm", norm_type="rmsnorm",
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
use_flash_attn=True, use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. num_chunks=2, # if num_chunks > 1, interleaved pipeline scheduler is used.
sequence_parallel=False, sequence_parallel=False,
num_experts=4, num_experts=4,
moe_use_residual=True, moe_use_residual=True,
@ -144,7 +144,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
""" """
parallel = dict( parallel = dict(
# zero1=8, # zero1=8,
pipeline=dict(size=4, interleaved_overlap=False), pipeline=dict(size=4, interleaved_overlap=True),
tensor=dict(size=2), tensor=dict(size=2),
) )

View File

@ -698,6 +698,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._input_objs = [[] for _ in range(num_chunks)] self._input_objs = [[] for _ in range(num_chunks)]
self._output_objs = [[] for _ in range(num_chunks)] self._output_objs = [[] for _ in range(num_chunks)]
self._output_obj_grads = [[] for _ in range(num_chunks)] self._output_obj_grads = [[] for _ in range(num_chunks)]
self._moe_losses = [[] for _ in range(num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)] self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
self._output_obj_shapes = [None for _ in range(num_chunks)] self._output_obj_shapes = [None for _ in range(num_chunks)]
@ -709,6 +710,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._input_objs = [[] for _ in range(self._num_chunks)] self._input_objs = [[] for _ in range(self._num_chunks)]
self._output_objs = [[] for _ in range(self._num_chunks)] self._output_objs = [[] for _ in range(self._num_chunks)]
self._output_obj_grads = [[] for _ in range(self._num_chunks)] self._output_obj_grads = [[] for _ in range(self._num_chunks)]
self._moe_losses = [[] for _ in range(self._num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)] self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
self._output_obj_shapes = [None for _ in range(self._num_chunks)] self._output_obj_shapes = [None for _ in range(self._num_chunks)]
@ -730,7 +732,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self.microbatch_offset[model_chunk_id] += self.microbatch_size self.microbatch_offset[model_chunk_id] += self.microbatch_size
return move_to_device(micro_batch_data) return move_to_device(micro_batch_data)
def _forward_step(self, engine, chunk_id): def _forward_step(self, engine, chunk_id, moe_loss_coeff:float=1.0):
"""Forward step for passed-in model. If it is the first stage, the input tensor """Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used. is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users. Returns output tensor. This is a helper function and can be ignored by users.
@ -752,7 +754,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data) self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model[chunk_id], data) output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage # Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel): if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj) output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@ -772,7 +774,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._accum_loss.add_(loss_reduced.detach()) self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced output_obj = loss_reduced
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
self._output_objs[chunk_id].append(output_obj) self._output_objs[chunk_id].append(output_obj)
self._moe_losses[chunk_id].append(moe_loss)
return output_obj return output_obj
@ -798,8 +804,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
input_obj = self._input_objs[chunk_id].pop(0) input_obj = self._input_objs[chunk_id].pop(0)
output_obj = self._output_objs[chunk_id].pop(0) output_obj = self._output_objs[chunk_id].pop(0)
output_obj_grad = self._output_obj_grads[chunk_id].pop(0) output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
moe_loss = self._moe_losses[chunk_id].pop(0)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad) input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss)
return input_obj_grad return input_obj_grad
@ -831,6 +838,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int, num_warmup_microsteps: int,
receive_extra_backward: bool = False, receive_extra_backward: bool = False,
forward_only: bool = False, forward_only: bool = False,
moe_loss_coeff: float = 1.0,
) -> None: ) -> None:
""" """
Run the warm-up loop and prepare data for the 1F1B stage. Run the warm-up loop and prepare data for the 1F1B stage.
@ -868,12 +876,13 @@ class InterleavedPipelineScheduler(PipelineScheduler):
for k in range(num_warmup_microsteps): for k in range(num_warmup_microsteps):
chunk_id = self._get_chunk_by_microbatch(k) chunk_id = self._get_chunk_by_microbatch(k)
output_obj = self._forward_step(engine, chunk_id) output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff)
if forward_only: if forward_only:
# when forward-only, no need to save tensors for a backward pass # when forward-only, no need to save tensors for a backward pass
self._input_objs[chunk_id].pop() self._input_objs[chunk_id].pop()
self._output_objs[chunk_id].pop() self._output_objs[chunk_id].pop()
self._moe_losses[chunk_id].pop()
if not gpc.is_pipeline_last_stage(): if not gpc.is_pipeline_last_stage():
if isinstance(output_obj, torch.Tensor): if isinstance(output_obj, torch.Tensor):
@ -949,6 +958,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int, num_warmup_microsteps: int,
num_1f1b_micropairs: int, num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False, all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None: ) -> None:
""" """
Run the 1F1B loop with overlap. Run the 1F1B loop with overlap.
@ -978,7 +988,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True) backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
# 1. Forward pass. # 1. Forward pass.
output_obj = self._forward_step(engine, forward_chunk_id) output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
# 2. Check if the backward input is ready. # 2. Check if the backward input is ready.
if backward_async_communicator is not None: if backward_async_communicator is not None:
@ -1063,6 +1073,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int, num_warmup_microsteps: int,
num_1f1b_micropairs: int, num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False, all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None: ) -> None:
""" """
Run the 1F1B loop without overlap. Run the 1F1B loop without overlap.
@ -1084,7 +1095,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# Forward pass. # Forward pass.
forward_microstep_id = k + num_warmup_microsteps forward_microstep_id = k + num_warmup_microsteps
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id) forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
output_obj = self._forward_step(engine, forward_chunk_id) output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
# Backward pass. # Backward pass.
backward_microstep_id = k backward_microstep_id = k
@ -1189,7 +1200,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
) )
) )
def _forward_only_step(self, engine: Engine): def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
num_microsteps = self.num_microbatches * self._num_chunks num_microsteps = self.num_microbatches * self._num_chunks
num_warmup_microsteps = num_microsteps num_warmup_microsteps = num_microsteps
@ -1199,9 +1210,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps, num_warmup_microsteps,
receive_extra_backward=False, receive_extra_backward=False,
forward_only=True, forward_only=True,
moe_loss_coeff=moe_loss_coeff,
) )
def _forward_backward_step(self, engine: Engine): def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
# Compute number of warmup and remaining microbatches. # Compute number of warmup and remaining microbatches.
all_warmup_microsteps = False all_warmup_microsteps = False
num_microsteps = self.num_microbatches * self._num_chunks num_microsteps = self.num_microbatches * self._num_chunks
@ -1235,6 +1247,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_microsteps, num_microsteps,
num_warmup_steps, num_warmup_steps,
receive_extra_backward=receive_extra_backward, receive_extra_backward=receive_extra_backward,
moe_loss_coeff=moe_loss_coeff,
) )
# 2. 1F1B # 2. 1F1B
@ -1243,12 +1256,14 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_steps, num_warmup_steps,
num_1f1b_micropairs=num_1f1b_micropairs, num_1f1b_micropairs=num_1f1b_micropairs,
all_warmup_microsteps=all_warmup_microsteps, all_warmup_microsteps=all_warmup_microsteps,
moe_loss_coeff=moe_loss_coeff,
) )
# 3. Cooldown # 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs) self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
return_output_label=True, moe_loss_coeff:float=1.0):
"""Run interleaved 1F1B schedule (model split into model chunks), with """Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed. communication between pipeline stages as needed.
@ -1276,9 +1291,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._return_tensors = [] self._return_tensors = []
if forward_only: if forward_only:
self._forward_only_step(engine) self._forward_only_step(engine, moe_loss_coeff)
else: else:
self._forward_backward_step(engine) self._forward_backward_step(engine, moe_loss_coeff)
if return_output_label and len(self._return_tensors) > 0: if return_output_label and len(self._return_tensors) > 0:
output, label = pack_return_tensors(self._return_tensors) output, label = pack_return_tensors(self._return_tensors)