From 44ca61a22b5265c56f279c45691dfae37bc001bd Mon Sep 17 00:00:00 2001
From: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Date: Mon, 5 Feb 2024 18:04:23 +0800
Subject: [PATCH] [llama] fix neftune & pbar with start_step (#5364)

---
 .../Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py | 2 +-
 applications/Colossal-LLaMA-2/train.py                      | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
index 9f6c9c1cc..21d769f3c 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
@@ -17,7 +17,7 @@ import torch
 
 
 def unwrap(model):
     if hasattr(model, "module"):
-        return unwrap_model(model.module)
+        return model.unwrap()
     else:
         return model

diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 314af923c..ebb919979 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -329,9 +329,9 @@ def main() -> None:
 
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
+        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
         total_loss = torch.tensor(0.0, device=get_current_device())
-        for step, batch in enumerate(dataloader):
+        for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
             batch_output = model(**batch)
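
---

Context for the two hunks. In neftune_patch.py, the old branch called unwrap_model, a name that does not appear to be imported in that file (the hunk context shows only "import torch"), so passing a wrapped model would raise a NameError; the fix delegates to the wrapper's own unwrap() method instead. In train.py, the two changes cooperate when training resumes mid-epoch: the progress bar counts optimizer steps, so its starting position is the restored micro-batch index divided by the gradient-accumulation factor, and enumerate() is started at start_step so the global step numbering (and any accumulation-boundary check keyed to it) matches the pre-resume run.

The following is a minimal, self-contained sketch of that resume pattern, not code from the patch: the numeric values, the range() stand-in for a dataloader that checkpoint logic has already fast-forwarded, and the (step + 1) % accumulation_steps boundary test are all illustrative assumptions.

from tqdm import tqdm

# Illustrative numbers; in the patch these come from args and the checkpoint.
accumulation_steps = 4
micro_batches_per_epoch = 16
start_step = 8  # micro-batches already consumed before the checkpoint was saved

# Stand-in for a dataloader that has already been fast-forwarded past the
# first `start_step` micro-batches by the checkpoint-restore logic.
dataloader = range(start_step, micro_batches_per_epoch)

# The bar counts optimizer steps, not micro-batches, so both `total` and
# `initial` are divided by the accumulation factor; without `initial`, a
# resumed run would misleadingly restart the bar at 0.
num_steps_per_epoch = micro_batches_per_epoch // accumulation_steps
pbar = tqdm(desc="Epoch 0", total=num_steps_per_epoch, initial=start_step // accumulation_steps)

for step, batch in enumerate(dataloader, start=start_step):
    # ... forward/backward on `batch` would run here ...
    # Starting `enumerate` at `start_step` keeps `step` global, so the
    # accumulation-boundary test below matches the pre-resume numbering.
    if (step + 1) % accumulation_steps == 0:
        pbar.update(1)  # one optimizer step finished
pbar.close()

Note that starting enumerate() at start_step matters beyond cosmetics: if step restarted at 0, a boundary test like the one above would fire on the wrong micro-batches whenever start_step is not a multiple of accumulation_steps, silently shifting where gradients are applied relative to the original run.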