|
|
|
@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
|
|
|
|
|
next_decoder_cache = None
|
|
|
|
|
|
|
|
|
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
|
|
|
|
num_ckpt_layers = 0
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
num_ckpt_layers = end_idx - start_idx
|
|
|
|
|
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
|
|
|
|
if shard_config.gradient_checkpoint_config is not None:
|
|
|
|
|
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
|
|
|
|
stage=stage_manager.stage,
|
|
|
|
|
num_stages=stage_manager.num_stages,
|
|
|
|
|
num_layers=end_idx - start_idx,
|
|
|
|
|
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
|
|
|
|
num_model_chunks=stage_manager.num_model_chunks,
|
|
|
|
|
)
|
|
|
|
|
assert num_ckpt_layers <= end_idx - start_idx
|
|
|
|
|
|
|
|
|
|
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states += (hidden_states,)
|
|
|
|
|
|
|
|
|
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
|
|
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
if idx - start_idx < num_ckpt_layers:
|
|
|
|
|
layer_outputs = self._gradient_checkpointing_func(
|
|
|
|
|
decoder_layer.__call__,
|
|
|
|
|
hidden_states,
|
|
|
|
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
|
|
|
|
|
|
|
|
|
|
if use_cache:
|
|
|
|
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
|
|
|
|
|
|
|
|
|
if output_attentions:
|
|
|
|
|
all_self_attns += (layer_outputs[1],)
|
|
|
|
|
|
|
|
|
|