Merge pull request #6208 from hpcaitech/grpo_dev

[Chat] fix colossalchat bugs
YeAnbang 2025-02-20 21:23:16 +08:00 committed by GitHub
commit b9e60559b8
8 changed files with 10 additions and 10 deletions


@@ -140,7 +140,7 @@ class NaiveExperienceMaker(ExperienceMaker):
num_actions = 0
for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
- s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size
+ s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
if input_ids[s:e].size(0) == 0:
break
sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
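
The slicing change above is the core of this fix: the loop variable already advances in steps of self.inference_batch_size, so the end of each mini-batch slice is simply the start index plus the batch size. The removed formula, (inference_mini_batch_id + 1) * self.inference_batch_size, produces oversized, overlapping slices for every mini-batch after the first. A standalone sketch of the two end-index formulas (the batch size of 4 and the prompt count of 10 are made up for illustration):

# Standalone sketch comparing the old and new end-index formulas; the values are
# hypothetical and nothing here depends on the ColossalChat trainer code.
inference_batch_size = 4
num_prompts = 10

for start in range(0, num_prompts, inference_batch_size):
    old_end = (start + 1) * inference_batch_size   # 4, 20, 36: overshoots after the first slice
    new_end = start + inference_batch_size         # 4, 8, 12: contiguous mini-batches
    print(f"slice old=[{start}:{old_end}] new=[{start}:{new_end}]")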


@@ -380,8 +380,8 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.get("accuracy"),
global_step,
)
- self.num_train_step += 1
self.accumulative_meter.reset()
+ self.num_train_step += 1
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint
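
After this change, the DPO trainer increments its step counter right after the accumulated metrics are reset, just before the save-interval check; the SFT hunk at the end of this diff makes the analogous move, and the KTO, ORPO and reward-model hunks touch the same counter line. A minimal sketch of the resulting bookkeeping order, with a made-up save_interval and dummy metrics standing in for the trainer's accumulative meter:

# Minimal sketch of the bookkeeping order inside one training step
# (save_interval, the meter and the checkpoint call are hypothetical stand-ins).
save_interval = 100
num_train_step = 0
meter = []                                   # stand-in for the trainer's accumulative meter

for batch in range(250):                     # stand-in for the train dataloader
    meter.append(0.0)                        # accumulate per-batch metrics
    # ... write accumulated metrics to the logger here ...
    meter.clear()                            # reset the meter for the next step
    num_train_step += 1                      # advance the counter after logging/reset
    if num_train_step > 0 and num_train_step % save_interval == 0:
        print(f"save checkpoint at step {num_train_step}")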


@@ -231,7 +231,6 @@ class GRPOTrainer(OLTrainer):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
- self.num_train_step += 1
self.actor.train()
num_actions = experience.action_log_probs.size(1)
# policy loss
@@ -294,7 +293,7 @@ class GRPOTrainer(OLTrainer):
self.temperature_annealing_scheduler.step_forward()
# preparing logging model output and corresponding rewards.
- if self.num_train_step % 10 == 1:
+ if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
@@ -327,6 +326,7 @@ class GRPOTrainer(OLTrainer):
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
self.accumulative_meter.reset()
+ self.num_train_step += 1
def _learn(self, update_step: int):
"""


@@ -256,7 +256,7 @@ class KTOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
- self.num_train_step += 1
+ self.num_train_step += 1
step_bar.close()


@@ -233,7 +233,7 @@ class ORPOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
- self.num_train_step += 1
+ self.num_train_step += 1
step_bar.close()


@@ -220,7 +220,6 @@ class PPOTrainer(OLTrainer):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
- self.num_train_step += 1
self.actor.train()
self.critic.train()
num_actions = experience.action_log_probs.size(1)
@@ -294,7 +293,7 @@ class PPOTrainer(OLTrainer):
self.critic_scheduler.step()
# preparing logging model output and corresponding rewards.
- if self.num_train_step % 10 == 1:
+ if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
@@ -336,6 +335,7 @@ class PPOTrainer(OLTrainer):
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
self.accumulative_meter.reset()
+ self.num_train_step += 1
def _learn(self, update_step: int):
"""


@@ -193,7 +193,7 @@ class RewardModelTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
- self.num_train_step += 1
+ self.num_train_step += 1
step_bar.close()
def _eval(self, epoch):


@@ -152,9 +152,9 @@ class SFTTrainer(SLTrainer):
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
- self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
+ self.num_train_step += 1
# Save checkpoint
if (