diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 60bc8b711..5855dcc4f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -150,10 +150,8 @@ class LlamaPipelineForwards: if shard_config.gradient_checkpoint_config is not None: num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, - num_stages=stage_manager.num_stages, num_layers=end_idx - start_idx, - model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), - num_model_chunks=stage_manager.num_model_chunks, + model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, ) assert num_ckpt_layers <= end_idx - start_idx