dev/zero-offload
Wang Binluo 2024-06-26 10:15:13 +08:00 committed by GitHub
parent 4c06215dce
commit de3f67d128
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 3 deletions

View File

@ -150,10 +150,8 @@ class LlamaPipelineForwards:
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
)
assert num_ckpt_layers <= end_idx - start_idx