[llama] fix memory issue (#5371)

* [llama] fix memory issue

* [llama] add comment
Hongxin Liu 2024-02-06 19:02:37 +08:00 committed by GitHub
parent eb4f2d90f9
commit 084c91246c
1 changed file with 4 additions and 4 deletions


@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,10 +232,12 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+        model = LlamaForCausalLM.from_pretrained(args.pretrained)
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)
+    # this is essential, otherwise the grad checkpoint will not work.
+    model.train()
 
     if args.use_grad_checkpoint:
         model.gradient_checkpointing_enable()
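
Why the extra model.train(): the old code built a randomly initialized model from LlamaConfig (which is also why LlamaConfig could be dropped from the import), and a freshly constructed nn.Module starts in training mode. LlamaForCausalLM.from_pretrained() instead returns the model in eval mode, and transformers only applies activation checkpointing when module.training is True, so without the explicit model.train() gradient checkpointing would silently be skipped. A minimal sketch of the new flow, assuming a placeholder checkpoint path and a plain nullcontext() in place of whatever init_ctx the script actually builds:

from contextlib import nullcontext

from transformers import LlamaForCausalLM

pretrained_path = "meta-llama/Llama-2-7b-hf"  # placeholder for args.pretrained
init_ctx = nullcontext()  # the real script may use a lazy-init context here

with init_ctx:
    # Build the model and load the checkpoint in one pass, instead of
    # allocating random weights from LlamaConfig and loading the same
    # checkpoint again later.
    model = LlamaForCausalLM.from_pretrained(pretrained_path)

# from_pretrained() leaves the model in eval mode; transformers skips
# activation checkpointing whenever module.training is False, so switch
# back to train mode before enabling it.
model.train()
model.gradient_checkpointing_enable()
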
@@ -277,8 +279,6 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
-    if args.load_checkpoint is None:
-        booster.load_model(model, args.pretrained)
 
     torch.set_default_dtype(torch.float)
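
The removed branch is the other half of the fix: with the weights already materialized by from_pretrained(), asking the booster to load args.pretrained a second time only re-reads the checkpoint and keeps a redundant copy of the weights around. A hedged sketch of the post-boost logic after this change; the booster.boost() call and its keyword arguments follow the diff context above, but the wrapper function and unpacking order are illustrative, not the script's exact code:

import torch
from colossalai.booster import Booster


def boost_and_prepare(booster: Booster, model, optimizer, lr_scheduler, dataloader, args):
    # Sketch only: wrap model/optimizer/dataloader with the chosen plugin.
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model,
        optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

    # Removed by this commit:
    #     if args.load_checkpoint is None:
    #         booster.load_model(model, args.pretrained)
    # The pretrained weights are already in the model, so this second load was
    # pure overhead; resuming from a saved training checkpoint (the
    # args.load_checkpoint path) is untouched by this diff.

    torch.set_default_dtype(torch.float)
    return model, optimizer, dataloader, lr_scheduler
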