Browse Source

fix llama (#5856)

dev/zero-offload
Wang Binluo 5 months ago committed by GitHub
parent
commit
de3f67d128
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      colossalai/shardformer/modeling/llama.py

4
colossalai/shardformer/modeling/llama.py

@ -150,10 +150,8 @@ class LlamaPipelineForwards:
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
)
assert num_ckpt_layers <= end_idx - start_idx

Loading…
Cancel
Save