diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index f10860fef..b250b4976 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -221,7 +221,7 @@ class OPTPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if decoder.gradient_checkpointing and decoder.training: - layer_outputs = self._gradient_checkpointing_func( + layer_outputs = self.decoder._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_attention_mask,