|
|
|
@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
|
|
|
|
|
next_decoder_cache = None |
|
|
|
|
|
|
|
|
|
start_idx, end_idx = stage_index[0], stage_index[1] |
|
|
|
|
num_ckpt_layers = 0 |
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
|
|
num_ckpt_layers = end_idx - start_idx |
|
|
|
|
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer |
|
|
|
|
if shard_config.gradient_checkpoint_config is not None: |
|
|
|
|
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( |
|
|
|
|
stage=stage_manager.stage, |
|
|
|
|
num_stages=stage_manager.num_stages, |
|
|
|
|
num_layers=end_idx - start_idx, |
|
|
|
|
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), |
|
|
|
|
num_model_chunks=stage_manager.num_model_chunks, |
|
|
|
|
) |
|
|
|
|
assert num_ckpt_layers <= end_idx - start_idx |
|
|
|
|
|
|
|
|
|
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): |
|
|
|
|
if output_hidden_states: |
|
|
|
|
all_hidden_states += (hidden_states,) |
|
|
|
|
|
|
|
|
|
past_key_value = past_key_values[idx] if past_key_values is not None else None |
|
|
|
|
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
|
|
if idx - start_idx < num_ckpt_layers: |
|
|
|
|
layer_outputs = self._gradient_checkpointing_func( |
|
|
|
|
decoder_layer.__call__, |
|
|
|
|
hidden_states, |
|
|
|
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
|
|
|
|
|
|
|
|
|
|
if use_cache: |
|
|
|
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
|
|
|
|
|
|
|
|
|
if output_attentions: |
|
|
|
|
all_self_attns += (layer_outputs[1],) |
|
|
|
|
|
|
|
|
|