update grpo

feat/grpo
Tong Li 2025-02-25 18:12:04 +08:00
parent ffd3878a1e
commit f736d747e3
1 changed file with 50 additions and 20 deletions


@@ -3,11 +3,13 @@ from typing import Optional
import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, is_rank_0
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.optimizer import HybridAdam
@@ -29,6 +31,8 @@ class GRPOConsumer(BaseConsumer):
model_config,
plugin_config,
microbatch_size=1,
num_generations=4,
use_wandb=False,
):
super().__init__(
num_producers,
@@ -50,6 +54,8 @@ class GRPOConsumer(BaseConsumer):
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
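# The reference model is never handed to the optimizer (only self.policy_model.parameters() is), so it stays a fixed snapshot of the starting policy for the KL term.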
@@ -57,6 +63,7 @@ class GRPOConsumer(BaseConsumer):
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
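# Number of completions sampled per prompt; rewards are later normalized within each group of this size.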
# Initialize verifiable reward.
response_format_tags = {
@@ -70,6 +77,8 @@ class GRPOConsumer(BaseConsumer):
)
self.policy_loss_fn = PolicyLoss()
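# Create the wandb run on rank 0 only, so metrics are logged once rather than once per data-parallel worker.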
if is_rank_0():
self.run = wandb.init(project="Colossal-GRPO-Test4")
def setup(self):
super().setup()
@@ -87,43 +96,52 @@ class GRPOConsumer(BaseConsumer):
},
...]
Format:
[batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
labels = kwargs["input_ids"].clone()
labels[kwargs["attention_mask"] == 0] = -100
kwargs["labels"] = labels
sequences = kwargs["input_ids"]
action_mask = kwargs["action_mask"]
# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
action_mask = data["action_mask"]
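# action_mask marks the response (action) tokens; its second dimension is the response length.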
num_action = action_mask.shape[1]
old_action_log_probs = kwargs["action_log_probs"]
assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
old_action_log_probs = data["action_log_probs"]
need_update = (step_idx + 1) % self.num_microbatches == 0
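# Gradient accumulation: synchronize and step only on the final micro-batch; otherwise booster.no_sync skips the gradient all-reduce.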
ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
input_ids=kwargs["input_ids"],
attention_mask=kwargs["attention_mask"],
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
reference_model_logits = self.reference_model(
input_ids=sequences,
attention_mask=kwargs["attention_mask"],
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
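# kl below is -0.1 times a per-token estimate of KL(policy || reference), averaged over response tokens; adding it to the reward penalizes drift from the reference model.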
# GRPO advantage calculation
kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
action_mask, dim=-1
)
reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
reward = kl + reward
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
# [batch_size x num_generations]
reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
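# The 1e-4 in the denominator keeps the advantage finite when every reward in a group is identical.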
# GRPO advantage calculation
kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
action_mask, dim=-1
)
reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
reward = reward + kl
mean = reward.view(-1, reward.size(0)).mean(dim=1)
std = reward.view(-1, reward.size(0)).std(dim=1)
advantages = (reward - mean) / (std + 1e-4)
# Calculate Loss
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
@@ -133,14 +151,26 @@ class GRPOConsumer(BaseConsumer):
)
loss = loss / self.num_microbatches
self.accum_loss.add_(loss.data)
if not skip_update:
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss)
reward = all_reduce_mean(reward.mean())
kl = all_reduce_mean(kl.mean())
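# Metrics are averaged across data-parallel ranks (all_reduce_mean) before being accumulated, so the logged values are global means.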
self.accum_loss.add_(loss.data)
self.accum_reward.add_(reward.data)
self.accum_kl.add_(kl.data)
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
if is_rank_0():
print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
self.run.log(
{"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_kl.zero_()
return loss_scalar
def state_dict(self):
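For reference, a minimal standalone sketch of the group-relative advantage normalization the updated step performs, assuming a flat reward tensor of shape [batch_size * num_generations] in which each consecutive block of num_generations entries comes from the same prompt. The helper name is illustrative and not part of the diff; the grouping and the 1e-4 epsilon mirror it.

import torch

def group_relative_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    # reward: [batch_size * num_generations]; consecutive blocks of
    # num_generations rewards belong to the same prompt.
    group_reward = reward.view(-1, num_generations)      # [batch_size, num_generations]
    mean = group_reward.mean(dim=1, keepdim=True)        # per-prompt baseline
    std = group_reward.std(dim=1, keepdim=True)          # per-prompt spread
    advantages = (group_reward - mean) / (std + eps)     # normalize within each group
    return advantages.view(-1)                           # back to [batch_size * num_generations]

# Example: 2 prompts x 4 generations each.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 3.0, 3.0, 2.0, 4.0])
print(group_relative_advantages(rewards, num_generations=4))

Broadcasting the per-group mean and std here is equivalent to the repeat_interleave formulation in the diff, which repeats each group statistic num_generations times along the flattened batch before subtracting and dividing.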