mirror of https://github.com/hpcaitech/ColossalAI
parent
eb4f2d90f9
commit
084c91246c
|
@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||||
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.accelerator import get_accelerator
|
from colossalai.accelerator import get_accelerator
|
||||||
|
@ -232,10 +232,12 @@ def main() -> None:
|
||||||
else nullcontext()
|
else nullcontext()
|
||||||
)
|
)
|
||||||
with init_ctx:
|
with init_ctx:
|
||||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
||||||
# Freeze part of parameters.
|
# Freeze part of parameters.
|
||||||
if args.freeze_non_embeds_params:
|
if args.freeze_non_embeds_params:
|
||||||
freeze_non_embeds_parameters(model=model)
|
freeze_non_embeds_parameters(model=model)
|
||||||
|
# this is essential, otherwise the grad checkpoint will not work.
|
||||||
|
model.train()
|
||||||
|
|
||||||
if args.use_grad_checkpoint:
|
if args.use_grad_checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
@ -277,8 +279,6 @@ def main() -> None:
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
dataloader=dataloader,
|
dataloader=dataloader,
|
||||||
)
|
)
|
||||||
if args.load_checkpoint is None:
|
|
||||||
booster.load_model(model, args.pretrained)
|
|
||||||
|
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue