
[llama] fix memory issue (#5371)

* [llama] fix memory issue

* [llama] add comment
Hongxin Liu committed via GitHub
commit 084c91246c
1 changed file: applications/Colossal-LLaMA-2/train.py — 8 changed lines (4 additions, 4 deletions)

@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer

 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,10 +232,12 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+        model = LlamaForCausalLM.from_pretrained(args.pretrained)
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)
+        # this is essential, otherwise the grad checkpoint will not work.
+        model.train()

     if args.use_grad_checkpoint:
         model.gradient_checkpointing_enable()
@@ -277,8 +279,6 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
-    if args.load_checkpoint is None:
-        booster.load_model(model, args.pretrained)

     torch.set_default_dtype(torch.float)
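
For context, the change replaces two-step initialization (build the model from its config, then pull the pretrained weights in afterwards with booster.load_model) with a single from_pretrained call inside the init context, so the full-precision weights are only materialized once; model.train() is added because gradient checkpointing is only applied while the model is in training mode. A minimal sketch of the resulting pattern follows; the helper name build_model and the plain nullcontext placeholder are illustrative assumptions, not code from the repository.

from contextlib import nullcontext

from transformers import LlamaForCausalLM


def build_model(pretrained_path: str, use_grad_checkpoint: bool = True):
    """Hypothetical helper showing the initialization pattern from the diff above."""
    # With a memory-saving plugin, train.py uses a lazy-init context here;
    # nullcontext() stands in for it so the sketch stays self-contained.
    init_ctx = nullcontext()
    with init_ctx:
        # Load the pretrained weights directly instead of constructing a
        # randomly initialized model from its config and loading the same
        # weights a second time afterwards.
        model = LlamaForCausalLM.from_pretrained(pretrained_path)
        # Gradient checkpointing only takes effect in training mode, so the
        # model is switched to train() before it is enabled.
        model.train()
    if use_grad_checkpoint:
        model.gradient_checkpointing_enable()
    return model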
