@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
|
|
@@ -232,7 +232,7 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM.from_pretrained(args.pretrained)
+        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
     # Freeze part of parameters.
     if args.freeze_non_embeds_params:
         freeze_non_embeds_parameters(model=model)
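
Note: constructing the model from `LlamaConfig` instead of calling `from_pretrained()` means only the module structure is built inside `init_ctx`. Under `LazyInitContext` (used with the Gemini and hybrid-parallel plugins) parameters stay unmaterialized, so each rank no longer holds a full copy of the pretrained weights at init time; the weights are loaded later through the booster (see the next hunk). A minimal sketch of the pattern, assuming `plugin`, `args.pretrained`, and `get_current_device` from the surrounding script:

```python
from contextlib import nullcontext

from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin
from colossalai.lazy import LazyInitContext
from transformers import LlamaConfig, LlamaForCausalLM

# Use lazy init only when the plugin shards/offloads parameters itself.
init_ctx = (
    LazyInitContext(default_device=get_current_device())
    if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
    else nullcontext()
)
with init_ctx:
    # Only the config JSON is read from disk here; no weight tensors are
    # loaded. booster.load_model() fills in the weights after boost().
    model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
```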
|
|
@@ -277,6 +277,8 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
+    if args.load_checkpoint is None:
+        booster.load_model(model, args.pretrained)
 
     torch.set_default_dtype(torch.float)
 
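
Note: this is the counterpart of the config-only init above. When the run is not resuming from a saved checkpoint, the pretrained weights are streamed into the model only after `booster.boost()` has wrapped (and possibly sharded) it, via `Booster.load_model`. A hedged sketch of the resulting flow, reusing the names from `train.py`:

```python
# boost() wraps model/optimizer/dataloader for the chosen plugin; after
# this the model may be sharded, so a plain from_pretrained() at
# construction time would no longer fit in one rank's memory.
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    dataloader=dataloader,
)

if args.load_checkpoint is None:
    # Fresh fine-tuning run: load the HF pretrained weights into the
    # boosted model; the booster handles lazy/sharded parameters.
    booster.load_model(model, args.pretrained)
```

The guard suggests that when `args.load_checkpoint` is set, this load is skipped because the subsequent checkpoint restore supplies the weights instead.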
|
|
|
|
|
|
|
|
|
|
@@ -329,7 +331,12 @@ def main() -> None:
 
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
+        pbar = tqdm(
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            total=num_steps_per_epoch,
+            initial=start_step // args.accumulation_steps,
+        )
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
|
@@ -369,6 +376,7 @@ def main() -> None:
                     coordinator.print_on_master("Deactivate NEFTune before saving model.")
                     deactivate_neftune(model, handle)
 
+                accelerator.empty_cache()
                 save_checkpoint(
                     save_dir=args.save_dir,
                     booster=booster,
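
Note: `accelerator` here comes from the newly imported `get_accelerator()` (first hunk), which abstracts over the device backend; on CUDA, `empty_cache()` releases the allocator's cached blocks, much like `torch.cuda.empty_cache()`. Dropping the cache right before `save_checkpoint()` gives the checkpointing path, which may gather large sharded states, extra memory headroom. A small illustrative sketch:

```python
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()  # e.g. the CUDA accelerator on NVIDIA GPUs

# Release cached device memory before the checkpoint logic gathers
# (potentially large) model and optimizer states.
accelerator.empty_cache()
```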
|
|
|