[llama] fix neftune & pbar with start_step (#5364)

Camille Zhong 2024-02-05 18:04:23 +08:00 committed by GitHub
parent a4cec1715b
commit 44ca61a22b
2 changed files with 3 additions and 3 deletions

@@ -17,7 +17,7 @@ import torch
 def unwrap(model):
     if hasattr(model, "module"):
-        return unwrap_model(model.module)
+        return model.unwrap()
     else:
         return model
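
Note: the removed line called unwrap_model, a name that does not appear anywhere in the surrounding code, so unwrapping a wrapped model could not work as written; the fix delegates to the wrapper's own unwrap(). Below is a minimal runnable sketch of the fixed behavior; ToyWrapper is hypothetical, standing in for a booster/DDP-style wrapper that keeps the raw model in .module (NEFTune needs the raw model so it can patch the input embedding).

import torch.nn as nn

class ToyWrapper(nn.Module):
    # Hypothetical stand-in for a wrapper that stores the raw model in
    # .module and exposes it via .unwrap().
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

def unwrap(model):
    # Wrapped models expose the raw module via unwrap(); plain models
    # pass through unchanged.
    if hasattr(model, "module"):
        return model.unwrap()
    else:
        return model

inner = nn.Linear(4, 4)
assert unwrap(ToyWrapper(inner)) is inner
assert unwrap(inner) is inner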

@@ -329,9 +329,9 @@ def main() -> None:
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
+        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
         total_loss = torch.tensor(0.0, device=get_current_device())
-        for step, batch in enumerate(dataloader):
+        for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
             batch_output = model(**batch)
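
This hunk fixes mid-epoch resumption: the progress bar now starts at the optimizer step corresponding to start_step (start_step // args.accumulation_steps, since the bar counts optimizer steps while start_step counts batches), and enumerate(dataloader, start=start_step) keeps the step counter aligned with the batches already consumed. A toy sketch of the bookkeeping, with illustrative values and a plain range standing in for the resumed dataloader:

from tqdm import tqdm

start_step = 8                       # batches consumed before the restart
accumulation_steps = 4               # one optimizer step per 4 batches
num_batches = 32                     # batches per epoch
num_steps_per_epoch = num_batches // accumulation_steps

# Stand-in for a dataloader that resumes after the first start_step batches.
remaining_batches = range(start_step, num_batches)
pbar = tqdm(desc="Epoch 0", total=num_steps_per_epoch, initial=start_step // accumulation_steps)

for step, batch in enumerate(remaining_batches, start=start_step):
    # ... forward/backward on `batch` would go here ...
    if (step + 1) % accumulation_steps == 0:
        pbar.update(1)               # advance once per optimizer step
pbar.close()

Without initial=, the resumed bar would start at zero and end the epoch short of its total, misreporting progress.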