[fix] fix mixtral policy;

pull/6077/head
duanjunwen 2024-10-08 09:34:09 +00:00
parent 292a504bea
commit cc500b3e25
1 changed file with 1 addition and 1 deletion

View File

@@ -269,7 +269,7 @@ class MixtralPolicy(Policy):
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
- stage_manager.is_last_stage(ignore_chunk=True)
+ not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
# for zbv, when is_first_stage (last fwd), we append norm
# for interleaved, when is_last_stage (last fwd), we also append norm