add response length

grpo-latest-npu
Tong Li 2025-03-11 13:06:09 +08:00
parent abca66e69f
commit 47d6493778
1 changed file with 19 additions and 9 deletions

@@ -59,6 +59,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
# Reference model is initialized from policy model.
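For context, the new accum_response_length buffer follows the same pattern as the other accumulators: one single-element tensor per metric, allocated on the training device so the in-place add_ calls later in the diff stay on-device. A minimal sketch of that pattern, using an illustrative MetricAccumulator helper that is not part of this patch:

    import torch

    class MetricAccumulator:
        """Running sums for scalar training metrics (illustrative helper only)."""

        def __init__(self, names, device):
            # One single-element buffer per metric, kept on the training device.
            self.buffers = {name: torch.zeros(1, device=device) for name in names}
            self.count = 0

        def add(self, **metrics):
            for name, value in metrics.items():
                self.buffers[name].add_(value)
            self.count += 1

        def means(self):
            # Average over however many micro-batches have been accumulated.
            denom = max(self.count, 1)
            return {name: (buf / denom).item() for name, buf in self.buffers.items()}

        def reset(self):
            for buf in self.buffers.values():
                buf.zero_()
            self.count = 0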
@@ -83,7 +84,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
def setup(self):
super().setup()
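This hunk only renames the wandb project ("GRPO-Test" to "GRPO-V1"); the surrounding guard creates the run on rank 0 alone so distributed workers do not open duplicate runs. A hedged sketch of that guard, assuming wandb is installed (the function name and run handle here are illustrative):

    import wandb

    def maybe_init_wandb(rank: int, use_wandb: bool):
        """Create a single wandb run on rank 0; other ranks log nothing."""
        if use_wandb and rank == 0:
            # sync_tensorboard mirrors any TensorBoard scalars into the same run.
            return wandb.init(project="GRPO-V1", sync_tensorboard=True)
        return None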
@@ -109,6 +110,7 @@ class GRPOConsumer(BaseConsumer):
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
need_update = (step_idx + 1) % self.num_microbatches == 0
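The added response_length line derives the per-sample generation length from the action mask: each row of action_mask has one entry per generated token, so summing over the sequence dimension counts the non-padded response tokens, and the cast to float32 lets the lengths be averaged and all-reduced like the other metrics. A small self-contained example (the tensor values are made up for illustration):

    import torch

    # action_mask: (batch, num_actions); 1 marks a real generated token, 0 marks padding.
    action_mask = torch.tensor([
        [1, 1, 1, 1, 0, 0],   # response of 4 tokens
        [1, 1, 0, 0, 0, 0],   # response of 2 tokens
    ])

    response_length = torch.sum(action_mask, dim=1).to(torch.float32)
    print(response_length)          # tensor([4., 2.])
    print(response_length.mean())   # tensor(3.)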
@@ -168,6 +170,7 @@ class GRPOConsumer(BaseConsumer):
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
# Calculate accumulate value.
self.accum_loss.add_(loss.data)
self.accum_reward.add_(reward.data)
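Before accumulation, each scalar is first averaged over the local batch (.mean()) and then averaged across data-parallel ranks via all_reduce_mean, so the logged response length reflects the global batch rather than a single worker. A hedged sketch of what such a helper typically does, assuming torch.distributed is already initialized (the repo's actual all_reduce_mean also takes a plugin argument, omitted here):

    import torch
    import torch.distributed as dist

    def all_reduce_mean_sketch(value: torch.Tensor) -> torch.Tensor:
        """Average a scalar tensor across all data-parallel ranks (illustrative only)."""
        value = value.clone()                          # do not modify the caller's tensor
        dist.all_reduce(value, op=dist.ReduceOp.SUM)   # sum over all ranks
        value /= dist.get_world_size()                 # convert the sum to a mean
        return value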
@@ -175,6 +178,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
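These add_ calls run on every micro-batch, while the optimizer step and the logging block below only fire when need_update is true, i.e. once per num_microbatches. A condensed sketch of that control flow, with hypothetical names (train_microbatch, metrics) standing in for the surrounding consumer code:

    def train_microbatch(self, step_idx: int, loss, metrics):
        """Accumulate per-micro-batch metrics; step the optimizer once per full batch."""
        need_update = (step_idx + 1) % self.num_microbatches == 0

        # Accumulate every micro-batch so the later averages cover the whole window.
        self.accum_loss.add_(loss.data)
        for name, value in metrics.items():
            getattr(self, f"accum_{name}").add_(value.data)
        self.accum_count += 1

        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()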
@@ -184,32 +188,38 @@ class GRPOConsumer(BaseConsumer):
print(
"Loss:",
self.accum_loss.item() / self.accum_count,
"Reward:",
"\nReward:",
self.accum_reward.item() / self.accum_count,
"KL:",
self.accum_kl.item() / self.accum_count,
"Format Reward:",
"\nFormat Reward:",
self.accum_format_reward.item() / self.accum_count,
"Acc Reward:",
"\nAcc Reward:",
self.accum_acc_reward.item() / self.accum_count,
"Advantages:",
"\nKL:",
self.accum_kl.item() / self.accum_count,
"\nAdvantages:",
self.accum_advantages.item() / self.accum_count,
"\nResponse Length:",
self.accum_response_length.item() / self.accum_count,
)
self.wandb_run.log(
{
"train/loss": self.accum_loss.item() / self.accum_count,
"train/reward": self.accum_reward.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/format_reward": self.accum_format_reward.item() / self.accum_count,
"train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/response_length": self.accum_response_length.item() / self.accum_count,
}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_kl.zero_()
self.accum_acc_reward.zero_()
self.accum_format_reward.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar
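Taken together, the end of the update step divides each accumulator by accum_count, reports the averages (stdout, plus wandb when enabled), and zeroes everything for the next accumulation window; this commit adds train/response_length to that report and moves the KL entries without changing their values. A compact, hedged sketch of the report-and-reset step, reusing the illustrative MetricAccumulator from the first sketch above (wandb_run may be None on non-zero ranks):

    def report_and_reset(acc, wandb_run=None):
        """Log averaged metrics under a train/ prefix, then clear the accumulators."""
        means = acc.means()
        print("\n".join(f"{name}: {value:.4f}" for name, value in means.items()))
        if wandb_run is not None:
            wandb_run.log({f"train/{name}": value for name, value in means.items()})
        acc.reset()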