[fix] MixtralForCausalLMPolicy get_held_layer support zbv;

pull/6083/head
duanjunwen 2024-10-10 05:40:22 +00:00
parent 3f5bec8dc4
commit 9ee80fc828
3 changed files with 18 additions and 3 deletions

@@ -268,6 +268,7 @@ class MixtralPipelineForwards:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         # retrieve input_ids and inputs_embeds
+        print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
         if stage_manager.is_first_stage():
             # retrieve input_ids and inputs_embeds
             if input_ids is not None and inputs_embeds is not None:
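
The `is_first_stage()` branch above is what keeps raw `input_ids` handling on the entry stage only; downstream stages receive `hidden_states` from their predecessor instead. A minimal sketch of that dispatch pattern (the `StageManagerStub` and the `embed` callable are illustrative stand-ins, not ColossalAI's API):

from dataclasses import dataclass

import torch


@dataclass
class StageManagerStub:
    stage: int
    num_stages: int
    model_chunk_id: int = 0

    def is_first_stage(self) -> bool:
        return self.stage == 0


def pipeline_forward(stage_manager, embed, input_ids=None, hidden_states=None):
    # Only the entry stage turns token ids into embeddings; every other stage
    # keeps transforming the hidden states handed over by its predecessor.
    if stage_manager.is_first_stage():
        assert input_ids is not None, "first stage expects input_ids"
        hidden_states = embed(input_ids)
    else:
        assert hidden_states is not None, "non-first stage expects hidden_states"
    return hidden_states


sm = StageManagerStub(stage=0, num_stages=4)
out = pipeline_forward(sm, torch.nn.Embedding(100, 16), input_ids=torch.randint(0, 100, (2, 8)))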

@@ -343,8 +343,18 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage():
-            held_layers.append(self.model.lm_head)
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                if stage_manager.is_first_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+            else:
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+        else:
+            if stage_manager.is_last_stage():
+                held_layers.append(self.model.lm_head)
+        # if stage_manager.is_last_stage():
+        #     held_layers.append(self.model.lm_head)
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
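
The point of the new branch for the zero-bubble-V (ZBV) schedule: each pipeline rank holds two model chunks laid out in a V shape, with chunk 0 running ranks 0 → p-1 and chunk 1 running back p-1 → 0, so the last decoder layers and the LM head land on rank 0's second chunk. That is why the head is held when `is_first_stage(ignore_chunk=True)` is true rather than on the last stage. A toy illustration of the same branching with a stand-in stage manager (not ColossalAI's `PipelineStageManager`; only the decision logic mirrors the policy code above):

from dataclasses import dataclass


@dataclass
class MiniStageManager:
    stage: int            # pipeline rank of this process
    num_stages: int       # pipeline-parallel world size
    is_interleave: bool   # interleaved (multi-chunk) schedule?
    use_zbv: bool         # zero-bubble-V schedule?

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        return self.stage == 0

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        return self.stage == self.num_stages - 1


def holds_lm_head(sm: MiniStageManager) -> bool:
    # Mirrors the get_held_layers() branching above.
    if sm.is_interleave:
        if sm.use_zbv:
            # V schedule: chunk 1 flows back to rank 0, which therefore owns
            # the final layers and the LM head.
            return sm.is_first_stage(ignore_chunk=True)
        return sm.is_last_stage(ignore_chunk=True)
    return sm.is_last_stage()


for rank in range(4):
    zbv = MiniStageManager(rank, 4, is_interleave=True, use_zbv=True)
    one_f_one_b = MiniStageManager(rank, 4, is_interleave=False, use_zbv=False)
    print(rank, holds_lm_head(zbv), holds_lm_head(one_f_one_b))
# ZBV: only rank 0 holds lm_head; the non-interleaved schedule: only rank 3.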

@@ -167,6 +167,7 @@ def main():
         enable_fused_normalization=torch.cuda.is_available(),
         enable_flash_attention=args.xformers,
         microbatch_size=args.mbs,
+        num_microbatches=args.batch_size // args.mbs,
         precision="bf16",
         enable_metadata_cache=not args.no_cache,
         overlap_allgather=args.overlap_allgather,
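
The added `num_microbatches` ties the pipeline schedule to the global batch: each step is split into `batch_size // mbs` microbatches of `mbs` samples. The integer division silently assumes the batch size is divisible by the microbatch size; a small sanity check (values are illustrative, not taken from the script):

def num_microbatches(batch_size: int, mbs: int) -> int:
    # Pipeline schedules need an exact split of the global batch.
    assert batch_size % mbs == 0, "batch_size must be a multiple of mbs"
    return batch_size // mbs


assert num_microbatches(32, 4) == 8
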
@@ -208,8 +209,10 @@ def main():
     with init_ctx:
         model = MixtralForCausalLM(config=config).to(torch.bfloat16)

+    # if args.grad_checkpoint:
+    #     model.gradient_checkpointing_enable()
     if args.grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
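
The reworked gradient-checkpointing call opts into the non-reentrant `torch.utils.checkpoint` path: `gradient_checkpointing_kwargs` is forwarded to each checkpointed block, so every recomputation runs with `use_reentrant=False`. A standalone sketch of the same toggle on a tiny Mixtral built from a config (the sizes are illustrative; this needs a transformers release that accepts `gradient_checkpointing_kwargs`):

import torch
from transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig(
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = MixtralForCausalLM(config).to(torch.bfloat16)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
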
@@ -224,6 +227,7 @@ def main():
     )

     optimizer = HybridAdam(model.parameters())
+    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
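
Switching the global default dtype to bf16 right before `booster.boost(...)` makes any tensors materialized during boosting come out in bf16 (the script presumably restores the default afterwards, which falls outside this hunk). A hypothetical scoped variant of the same idea, so the switch cannot leak into later code:

from contextlib import contextmanager

import torch


@contextmanager
def default_dtype(dtype: torch.dtype):
    # Temporarily swap the global default dtype, restoring it on exit.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


# Usage sketch (booster, model, optimizer, dataloader come from the script):
# with default_dtype(torch.bfloat16):
#     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)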