[Hotfix] Fix OPT gradient checkpointing forward

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
pull/5882/head
Edenzzzz 2024-07-03 14:57:57 +08:00 committed by GitHub
parent ea94c07b95
commit eb24fcd914
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -221,7 +221,7 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
layer_outputs = self._gradient_checkpointing_func(
layer_outputs = self.decoder._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,