fix inference rebatching bug

grpo_dev
YeAnbang 2025-02-20 17:25:36 +08:00
parent 9379cbd668
commit 0171884664
1 changed file with 1 addition and 1 deletion

@@ -140,7 +140,7 @@ class NaiveExperienceMaker(ExperienceMaker):
         num_actions = 0
         for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
-            s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size
+            s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
             if input_ids[s:e].size(0) == 0:
                 break
             sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
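
Why this fix works: `inference_mini_batch_id` already advances in steps of `self.inference_batch_size`, so the end of each slice must be start + batch size. The old expression `(start + 1) * batch_size` is only correct on the first iteration (start = 0); from the second iteration on it balloons, producing an oversized mini-batch that overlaps the ones after it. A minimal standalone sketch of the slicing arithmetic (the names `input_ids` and `inference_batch_size` mirror the diff; the toy tensor is made up for illustration):

    import torch

    inference_batch_size = 4
    input_ids = torch.arange(10).unsqueeze(-1)  # 10 toy samples, one token each

    # Old (buggy) end index: only correct when start == 0, because start
    # itself already moves in steps of inference_batch_size.
    buggy = [(s, (s + 1) * inference_batch_size)
             for s in range(0, input_ids.size(0), inference_batch_size)]
    print(buggy)  # [(0, 4), (4, 20), (8, 36)] -- slices are oversized and overlap

    # Fixed end index: even, non-overlapping mini-batches.
    fixed = [(s, s + inference_batch_size)
             for s in range(0, input_ids.size(0), inference_batch_size)]
    print(fixed)  # [(0, 4), (4, 8), (8, 12)]

Slicing past the end of a tensor is safe in PyTorch, so the last (possibly partial) mini-batch needs no special handling.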