update grpo

feat/grpo
Tong Li 2025-02-25 18:12:04 +08:00
parent ffd3878a1e
commit f736d747e3
1 changed file with 50 additions and 20 deletions


@@ -3,11 +3,13 @@ from typing import Optional

 import ray
 import torch
+import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
+from coati.trainer.utils import all_reduce_mean, is_rank_0
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.optimizer import HybridAdam
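The two helpers imported from coati.trainer.utils drive the new metric logging below. Their implementation is not part of this diff; a minimal sketch, assuming they follow the usual torch.distributed pattern:

    import torch
    import torch.distributed as dist

    def is_rank_0() -> bool:
        # Treat a non-distributed run as rank 0 so single-GPU logging still fires.
        return not dist.is_initialized() or dist.get_rank() == 0

    def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
        # Sum the tensor across all ranks in place, then divide by the world size.
        if dist.is_initialized():
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
            tensor.div_(dist.get_world_size())
        return tensor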
@@ -29,6 +31,8 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
+        num_generations=4,
+        use_wandb=False,
     ):
         super().__init__(
             num_producers,
@@ -50,6 +54,8 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
         self.accum_loss = torch.zeros(1, device=self.device)
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_kl = torch.zeros(1, device=self.device)

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -57,6 +63,7 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations

         # Initialize verifiable reward.
         response_format_tags = {
@@ -70,6 +77,8 @@ class GRPOConsumer(BaseConsumer):
         )
         self.policy_loss_fn = PolicyLoss()
+        if is_rank_0():
+            self.run = wandb.init(project="Colossal-GRPO-Test4")

     def setup(self):
         super().setup()
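Creating the wandb run only on rank 0 keeps one run per job instead of one per data-parallel worker; every later self.run.log call must sit behind the same guard. A standalone sketch of the pattern with a hypothetical use_wandb gate (the new constructor argument is not wired up in this hunk, so gating on it is an assumption):

    import wandb

    def maybe_init_wandb(use_wandb: bool, project: str = "Colossal-GRPO-Test4"):
        # Hypothetical helper: only rank 0 creates a run; other ranks get
        # None back and must skip logging.
        if use_wandb and is_rank_0():
            return wandb.init(project=project)
        return None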
@@ -87,43 +96,52 @@ class GRPOConsumer(BaseConsumer):
                 },
             ...]
         Format:
-            [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
+            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
-        labels = kwargs["input_ids"].clone()
-        labels[kwargs["attention_mask"] == 0] = -100
-        kwargs["labels"] = labels
-        sequences = kwargs["input_ids"]
-        action_mask = kwargs["action_mask"]
+        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
-        old_action_log_probs = kwargs["action_log_probs"]
-        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
+        old_action_log_probs = data["action_log_probs"]

         need_update = (step_idx + 1) % self.num_microbatches == 0

         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
-                input_ids=kwargs["input_ids"],
-                attention_mask=kwargs["attention_mask"],
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
             )["logits"]
-            action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
+            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)

             reference_model_logits = self.reference_model(
-                input_ids=sequences,
-                attention_mask=kwargs["attention_mask"],
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
             )["logits"]
-            reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
+            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)

+            # GRPO advantage calculation
+            kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
+                action_mask, dim=-1
+            )
+            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = kl + reward
+            # [batch_size, num_generations]
+            group_reward = reward.view(-1, self.num_generations)
+            # [batch_size x num_generations]
+            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
+
             # GRPO advantage calculation
             kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
                 action_mask, dim=-1
             )
-            reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
-            reward = reward + kl
-            mean = reward.view(-1, reward.size(0)).mean(dim=1)
-            std = reward.view(-1, reward.size(0)).std(dim=1)
-            advantages = (reward - mean) / (std + 1e-4)

             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
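The core of the change is the group-relative advantage: each prompt now contributes num_generations rollouts, rewards are reshaped into per-prompt groups, and every rollout is normalized by its own group's mean and standard deviation. A self-contained numeric sketch of that reshape-and-normalize step (shapes and the 1e-4 epsilon follow the diff; the reward values are illustrative):

    import torch

    num_generations = 4
    # Rewards for 2 prompts x 4 generations, flattened to [batch_size x num_generations].
    reward = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0, 1.0])

    # [batch_size, num_generations]: one row per prompt group.
    group_reward = reward.view(-1, num_generations)

    # Broadcast each group's statistics back over its members.
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)

    # Rollouts above their group mean get positive advantages, those below negative.
    advantages = (reward - reward_mean) / (reward_std + 1e-4)
    print(advantages)  # first group: approximately [0.87, -0.87, 0.87, -0.87]

Because normalization is within-group, an all-correct or all-wrong group yields zero advantages (the epsilon keeps the division finite when a group's rewards tie), so only within-prompt reward differences drive the policy update.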
@@ -133,14 +151,26 @@ class GRPOConsumer(BaseConsumer):
             )
             loss = loss / self.num_microbatches
-            self.accum_loss.add_(loss.data)

             if not skip_update:
                 self.booster.backward(loss, self.optimizer)
+            loss = all_reduce_mean(loss)
+            reward = all_reduce_mean(reward.mean())
+            kl = all_reduce_mean(kl.mean())
+            self.accum_loss.add_(loss.data)
+            self.accum_reward.add_(reward.data)
+            self.accum_kl.add_(kl.data)

         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
+            if is_rank_0():
+                print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
+                self.run.log(
+                    {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                )
             self.accum_loss.zero_()
+            self.accum_reward.zero_()
+            self.accum_kl.zero_()
             return loss_scalar

     def state_dict(self):
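The control flow around the loss implements microbatch gradient accumulation: gradients are only synchronized across data-parallel ranks on the last microbatch of each window, which is what the booster.no_sync context earlier in the method is for. A minimal sketch of the same pattern with plain PyTorch DDP standing in for the booster (the function name and split are hypothetical):

    import torch
    from contextlib import nullcontext
    from torch.nn.parallel import DistributedDataParallel as DDP

    def accumulation_step(model: DDP, optimizer, loss: torch.Tensor,
                          step_idx: int, num_microbatches: int) -> None:
        # Only the last microbatch of each window all-reduces gradients
        # and takes an optimizer step; earlier ones skip the sync.
        need_update = (step_idx + 1) % num_microbatches == 0
        ctx = nullcontext() if need_update else model.no_sync()
        with ctx:
            (loss / num_microbatches).backward()
        if need_update:
            optimizer.step()
            optimizer.zero_grad()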