mirror of https://github.com/hpcaitech/ColossalAI
[llama] fix neftune & pbar with start_step (#5364)
parent
a4cec1715b
commit
44ca61a22b
|
@ -17,7 +17,7 @@ import torch
|
|||
|
||||
def unwrap(model):
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
return model.unwrap()
|
||||
else:
|
||||
return model
|
||||
|
||||
|
|
|
@ -329,9 +329,9 @@ def main() -> None:
|
|||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
|
||||
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
|
||||
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||
for step, batch in enumerate(dataloader):
|
||||
for step, batch in enumerate(dataloader, start=start_step):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
|
|
Loading…
Reference in New Issue