fix llama pretrain (#5287)

pull/5294/head
flybird11111 2024-01-19 17:49:02 +08:00 committed by GitHub
parent 6a56967855
commit f7e3f82a7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 2 deletions

View File

@ -273,11 +273,10 @@ def main():
dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch)
step_nums = num_steps_per_epoch - start_step
dataloader_iter = iter(dataloader)
with tqdm(
range(step_nums),
range(start_step, num_steps_per_epoch),
desc=f"Epoch {epoch}",
disable=not print_flag,
total=num_steps_per_epoch,