mirror of https://github.com/hpcaitech/ColossalAI
fix num_train_step update
parent 0171884664
commit 53834b74b9
@@ -380,8 +380,8 @@ class DPOTrainer(SLTrainer):
 self.accumulative_meter.get("accuracy"),
 global_step,
 )
-self.num_train_step += 1
 self.accumulative_meter.reset()
+self.num_train_step += 1

 if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
 # save checkpoint
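Taken together, the hunks in this commit move the `self.num_train_step` increment out of the logging path and to the end of the training step, so the counter advances exactly once per step before the checkpoint-interval check runs. A minimal, self-contained sketch of that pattern, assuming made-up stand-ins (`do_optimizer_step`, `log_metrics`, `save_checkpoint`, `run_steps`) rather than the trainers' real helpers:

# Minimal sketch (not coati code): the step counter advances once per optimizer
# step, after logging/meter reset, and the save check reads the post-increment
# value. All names below are illustrative stand-ins.

def do_optimizer_step(batch):
    return sum(batch) / len(batch)  # pretend loss

def log_metrics(step, loss):
    print(f"step {step}: loss={loss:.3f}")

def save_checkpoint(step):
    print(f"saving checkpoint at step {step}")

def run_steps(batches, save_interval=2, save_dir="ckpts"):
    num_train_step = 0
    for batch in batches:
        loss = do_optimizer_step(batch)
        log_metrics(num_train_step, loss)  # logging and meter reset happen first
        num_train_step += 1                # then the counter advances, once per step

        # fires every `save_interval` completed steps, independent of logging
        if save_dir is not None and num_train_step > 0 and num_train_step % save_interval == 0:
            save_checkpoint(num_train_step)

run_steps([[1.0, 2.0], [0.5, 1.5], [2.0, 2.0], [1.0, 1.0]])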
@@ -231,7 +231,6 @@ class GRPOTrainer(OLTrainer):
 experience:
 sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
 """
-self.num_train_step += 1
 self.actor.train()
 num_actions = experience.action_log_probs.size(1)
 # policy loss
@@ -294,7 +293,7 @@ class GRPOTrainer(OLTrainer):
 self.temperature_annealing_scheduler.step_forward()

 # preparing logging model output and corresponding rewards.
-if self.num_train_step % 10 == 1:
+if self.num_train_step % 10 == 0:
 response_text = self.experience_maker.tokenizer.batch_decode(
 experience.sequences, skip_special_tokens=True
 )
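With the increment moved to the end of the step, the value of `self.num_train_step` observed inside the training step is now 0-based, which is why the sample-decoding condition changes from `% 10 == 1` to `% 10 == 0`. A tiny sketch of that off-by-one, under the assumption that the old code bumped the counter at the top of the step (as the first GRPO hunk suggests); the loop below is illustrative only:

# Which counter values trigger sample decoding under the old vs. new placement.
# Illustrative only; `logged` just records when the %-10 condition would fire.

def steps_that_log(num_steps, increment_first):
    logged, step = [], 0
    for _ in range(num_steps):
        if increment_first:        # old: counter bumped at the top of the step
            step += 1
            if step % 10 == 1:
                logged.append(step)
        else:                      # new: counter bumped at the end of the step
            if step % 10 == 0:
                logged.append(step)
            step += 1
    return logged

print(steps_that_log(30, increment_first=True))   # [1, 11, 21]
print(steps_that_log(30, increment_first=False))  # [0, 10, 20]  -- same cadence, first step included

Either way the decode fires on the 1st, 11th, 21st, ... training step; only the counter value seen at that moment differs.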
@@ -327,6 +326,7 @@ class GRPOTrainer(OLTrainer):
 self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
 self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
 self.accumulative_meter.reset()
+self.num_train_step += 1

 def _learn(self, update_step: int):
 """
@@ -220,7 +220,6 @@ class PPOTrainer(OLTrainer):
 experience:
 sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
 """
-self.num_train_step += 1
 self.actor.train()
 self.critic.train()
 num_actions = experience.action_log_probs.size(1)
@@ -294,7 +293,7 @@ class PPOTrainer(OLTrainer):
 self.critic_scheduler.step()

 # preparing logging model output and corresponding rewards.
-if self.num_train_step % 10 == 1:
+if self.num_train_step % 10 == 0:
 response_text = self.experience_maker.tokenizer.batch_decode(
 experience.sequences, skip_special_tokens=True
 )
@@ -336,6 +335,7 @@ class PPOTrainer(OLTrainer):
 self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
 self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
 self.accumulative_meter.reset()
+self.num_train_step += 1

 def _learn(self, update_step: int):
 """
@@ -152,9 +152,9 @@ class SFTTrainer(SLTrainer):
 if self.writer:
 self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
 self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
-self.num_train_step += 1
 self.accumulative_meter.reset()
 step_bar.update()
+self.num_train_step += 1

 # Save checkpoint
 if (
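In the SFTTrainer hunk the removed increment sat alongside the TensorBoard calls; the flattened diff loses indentation, but if it was indeed inside the `if self.writer:` branch, step counting depended on a writer being attached. After the change it advances unconditionally, after the meter reset and progress-bar update. A small illustrative sketch of that distinction, with a made-up FakeWriter standing in for a SummaryWriter:

# Illustrative only: keeping the step counter outside the logging branch.
# FakeWriter is a stand-in for a TensorBoard SummaryWriter, not coati code.

class FakeWriter:
    def add_scalar(self, tag, value, step):
        print(f"{tag}={value} @ step {step}")

def run_epoch(losses, writer=None):
    num_train_step = 0
    for loss in losses:
        if writer:
            writer.add_scalar("train/loss", loss, num_train_step)
        num_train_step += 1  # advances whether or not a writer is attached
    return num_train_step

# Both runs report 3 completed steps; if the increment lived inside `if writer:`,
# the writer-less run would report 0.
print(run_epoch([0.9, 0.7, 0.5], writer=FakeWriter()))  # 3
print(run_epoch([0.9, 0.7, 0.5], writer=None))          # 3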