mirror of https://github.com/hpcaitech/ColossalAI
[fix] MixtralForCausalLMPolicy get_held_layer support zbv;
parent 3f5bec8dc4
commit 9ee80fc828
@@ -268,6 +268,7 @@ class MixtralPipelineForwards:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
+        print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
         if stage_manager.is_first_stage():
             # retrieve input_ids and inputs_embeds
             if input_ids is not None and inputs_embeds is not None:
@@ -343,8 +343,18 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage():
-            held_layers.append(self.model.lm_head)
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                if stage_manager.is_first_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+            else:
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+        else:
+            if stage_manager.is_last_stage():
+                held_layers.append(self.model.lm_head)
+        # if stage_manager.is_last_stage():
+        #     held_layers.append(self.model.lm_head)
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
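The new branch above decides which pipeline rank keeps lm_head. With an interleaved schedule and ZBV enabled, the V-shaped schedule maps the last model chunk back onto pipeline rank 0, so the head is held where is_first_stage(ignore_chunk=True) is true; in every other case it stays with the last stage. Below is a minimal sketch, not ColossalAI code, that isolates that rule so it can be read or tested on its own; FakeStageManager and holds_lm_head are made-up names, and only the attribute and method names mirror the stage manager usage shown in the diff.

    # Minimal sketch (assumption: not part of the commit) of the lm_head placement rule.
    from dataclasses import dataclass

    @dataclass
    class FakeStageManager:
        is_interleave: bool
        use_zbv: bool
        first: bool  # is this rank pipeline rank 0?
        last: bool   # is this rank the final pipeline rank?

        def is_first_stage(self, ignore_chunk: bool = False) -> bool:
            return self.first

        def is_last_stage(self, ignore_chunk: bool = False) -> bool:
            return self.last

    def holds_lm_head(sm: FakeStageManager) -> bool:
        # Same branching as MixtralForCausalLMPolicy.get_held_layers in the diff above.
        if sm.is_interleave:
            if sm.use_zbv:
                # ZBV runs the last model chunk on rank 0, so the head follows it there.
                return sm.is_first_stage(ignore_chunk=True)
            return sm.is_last_stage(ignore_chunk=True)
        return sm.is_last_stage()

    # ZBV on: rank 0 holds lm_head, the last rank does not.
    assert holds_lm_head(FakeStageManager(True, True, first=True, last=False))
    assert not holds_lm_head(FakeStageManager(True, True, first=False, last=True))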
@@ -167,6 +167,7 @@ def main():
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
+            num_microbatches=args.batch_size // args.mbs,
             precision="bf16",
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
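The plugin configuration above derives the microbatch count from the command-line arguments. The snippet below just spells out that arithmetic with example values; the --batch_size/--mbs names come from the diff, while the concrete numbers and the divisibility check are an illustration rather than something the commit adds.

    # Worked example of num_microbatches = batch_size // mbs (values are made up).
    batch_size, mbs = 32, 1
    assert batch_size % mbs == 0, "global batch size must be divisible by the microbatch size"
    num_microbatches = batch_size // mbs
    print(num_microbatches)  # 32 microbatches per step of the pipeline schedule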
@@ -208,8 +209,10 @@ def main():
     with init_ctx:
         model = MixtralForCausalLM(config=config).to(torch.bfloat16)

+    # if args.grad_checkpoint:
+    #     model.gradient_checkpointing_enable()
     if args.grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
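The checkpointing change above switches the benchmark to the keyword form of gradient_checkpointing_enable, passing use_reentrant=False so that torch's non-reentrant activation-checkpointing implementation is used (the variant PyTorch currently recommends). A standalone illustration follows; the tiny MixtralConfig sizes are made up so the model builds quickly, and only the enable call itself reflects the code in the diff.

    # Standalone illustration (config sizes are arbitrary, not from the commit).
    from transformers import MixtralConfig, MixtralForCausalLM

    config = MixtralConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_local_experts=4,
    )
    model = MixtralForCausalLM(config)
    # use_reentrant=False selects torch.utils.checkpoint's non-reentrant implementation,
    # which PyTorch recommends over the legacy reentrant one.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})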
@@ -224,6 +227,7 @@ def main():
     )

     optimizer = HybridAdam(model.parameters())
+    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)