mirror of https://github.com/hpcaitech/ColossalAI
add response length
parent
abca66e69f
commit
47d6493778
|
@ -59,6 +59,7 @@ class GRPOConsumer(BaseConsumer):
|
|||
self.accum_format_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_acc_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
||||
self.accum_count = 0
|
||||
|
||||
# Reference model is initialized from policy model.
|
||||
|
@ -83,7 +84,7 @@ class GRPOConsumer(BaseConsumer):
|
|||
self.policy_loss_fn = PolicyLoss()
|
||||
self.global_step = 0
|
||||
if use_wandb and self.rank == 0:
|
||||
self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
|
||||
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
|
@ -109,6 +110,7 @@ class GRPOConsumer(BaseConsumer):
|
|||
action_mask = data["action_mask"]
|
||||
num_action = action_mask.shape[1]
|
||||
old_action_log_probs = data["action_log_probs"]
|
||||
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
||||
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
|
||||
|
@ -168,6 +170,7 @@ class GRPOConsumer(BaseConsumer):
|
|||
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
||||
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
self.accum_loss.add_(loss.data)
|
||||
self.accum_reward.add_(reward.data)
|
||||
|
@ -175,6 +178,7 @@ class GRPOConsumer(BaseConsumer):
|
|||
self.accum_format_reward.add_(format_reward.data)
|
||||
self.accum_acc_reward.add_(acc_reward.data)
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
self.accum_response_length.add_(response_length.data)
|
||||
self.accum_count += 1
|
||||
if need_update:
|
||||
self.optimizer.step()
|
||||
|
@ -184,32 +188,38 @@ class GRPOConsumer(BaseConsumer):
|
|||
print(
|
||||
"Loss:",
|
||||
self.accum_loss.item() / self.accum_count,
|
||||
"Reward:",
|
||||
"\nReward:",
|
||||
self.accum_reward.item() / self.accum_count,
|
||||
"KL:",
|
||||
self.accum_kl.item() / self.accum_count,
|
||||
"Format Reward:",
|
||||
"\nFormat Reward:",
|
||||
self.accum_format_reward.item() / self.accum_count,
|
||||
"Acc Reward:",
|
||||
"\nAcc Reward:",
|
||||
self.accum_acc_reward.item() / self.accum_count,
|
||||
"Advantages:",
|
||||
"\nKL:",
|
||||
self.accum_kl.item() / self.accum_count,
|
||||
"\nAdvantages:",
|
||||
self.accum_advantages.item() / self.accum_count,
|
||||
"\nResponse Length:",
|
||||
self.accum_response_length.item() / self.accum_count,
|
||||
)
|
||||
self.wandb_run.log(
|
||||
{
|
||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||
"train/reward": self.accum_reward.item() / self.accum_count,
|
||||
"train/kl": self.accum_kl.item() / self.accum_count,
|
||||
"train/format_reward": self.accum_format_reward.item() / self.accum_count,
|
||||
"train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
|
||||
"train/kl": self.accum_kl.item() / self.accum_count,
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/response_length": self.accum_response_length.item() / self.accum_count,
|
||||
}
|
||||
)
|
||||
self.accum_loss.zero_()
|
||||
self.accum_reward.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_acc_reward.zero_()
|
||||
self.accum_format_reward.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_response_length.zero_()
|
||||
|
||||
self.accum_count = 0
|
||||
return loss_scalar
|
||||
|
||||
|
|
Loading…
Reference in New Issue