mirror of https://github.com/hpcaitech/ColossalAI
fix llama (#5856)
parent
4c06215dce
commit
de3f67d128
|
@ -150,10 +150,8 @@ class LlamaPipelineForwards:
|
|||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
|
|
Loading…
Reference in New Issue