add simple grpo

feat/grpo
Tong Li 2025-02-23 22:54:26 +08:00
parent 8e6c9a4ab3
commit ffd3878a1e
8 changed files with 253 additions and 21 deletions

View File

@ -356,6 +356,12 @@ def apply_chat_template_and_mask(
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
chat = [chat["messages"]]
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
@ -389,6 +395,11 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
if gt_answer is not None:
gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,

View File

@ -0,0 +1,150 @@
from contextlib import nullcontext
from typing import Optional
import ray
import torch
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.optimizer import HybridAdam
@ray.remote
class GRPOConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
self.accum_loss = torch.zeros(1, device=self.device)
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
)
self.policy_loss_fn = PolicyLoss()
def setup(self):
super().setup()
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
self.reference_model, *_ = self.booster.boost(self.reference_model)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""
Step data from policy model:
[{
"input_ids": torch.Tensor,
"attention_mask": torch.Tensor,
"action_mask": torch.Tensor,
"action_log_probs": torch.Tensor,
},
...]
Format:
[batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
labels = kwargs["input_ids"].clone()
labels[kwargs["attention_mask"] == 0] = -100
kwargs["labels"] = labels
sequences = kwargs["input_ids"]
action_mask = kwargs["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = kwargs["action_log_probs"]
assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
need_update = (step_idx + 1) % self.num_microbatches == 0
ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
input_ids=kwargs["input_ids"],
attention_mask=kwargs["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
reference_model_logits = self.reference_model(
input_ids=sequences,
attention_mask=kwargs["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
# GRPO advantage calculation
kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
action_mask, dim=-1
)
reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
reward = reward + kl
mean = reward.view(-1, reward.size(0)).mean(dim=1)
std = reward.view(-1, reward.size(0)).std(dim=1)
advantages = (reward - mean) / (std + 1e-4)
# Calculate Loss
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
action_mask,
)
loss = loss / self.num_microbatches
self.accum_loss.add_(loss.data)
if not skip_update:
self.booster.backward(loss, self.optimizer)
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
self.accum_loss.zero_()
return loss_scalar
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict

View File

@ -210,6 +210,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"action_log_probs": log_probs,
"action_mask": action_mask,
}
if "gt_answer" in kwargs:
data["gt_answer"] = kwargs["gt_answer"]
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
@ -68,7 +68,7 @@ def launch_distributed(
)
procs.append(producer)
for i in range(num_consumer_procs):
consumer = SimpleConsumer.options(num_gpus=1).remote(
consumer = GRPOConsumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,

View File

@ -0,0 +1,44 @@
from typing import Optional
import torch
import torch.nn as nn
from coati.distributed.utils import masked_mean
class PolicyLoss(nn.Module):
"""
Policy Loss for PPO
"""
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
super().__init__()
self.clip_eps = clip_eps
self.skip_threshold = skip_threshold
def forward(
self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
if action_mask is None:
ratio_ = (log_probs - old_log_probs).exp()
else:
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
skip = True
ratio = ratio_.clamp(0.0, 10.0)
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()

View File

@ -3,17 +3,13 @@ import torch
from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, **kwargs):
# apply varifiable reward
# reward 10 points if the final answer is correct, reward 1 point if format is correct
gt_answer = kwargs["gt_answer"]
def math_reward_fn(input_ids, gt_answer, **kwargs):
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0))
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])

View File

@ -8,33 +8,27 @@ import torch
class VerifiableReward:
def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]):
self.reward_fn = reward_fn
self.reward_args = reward_args
def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
self.reward_fns = reward_fns
self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
response_start: List[int] = None,
response_end: List[int] = None,
gt_answer: List[str] = None,
gt_answer: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
# Initialize reward
reward = torch.zeros(bs, device=input_ids.device)
rewards = torch.zeros(bs, device=input_ids.device)
# Loop through reward functions
for reward_fn in self.reward_fn_list:
for reward_fn in self.reward_fns:
# Apply the reward function to the entire batch at once
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
attention_mask[i],
response_start=response_start[i],
response_end=response_end[i],
gt_answer=gt_answer[i],
**self.kwargs,
)

View File

@ -64,3 +64,38 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
log_probs = torch.log_softmax(logits, dim=-1)
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return per_label_logps.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.
Args:
output (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
Returns:
torch.Tensor: Action log probs.
"""
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
Compute the masked mean of a tensor along a specified dimension.
Args:
tensor (torch.Tensor): The input tensor.
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
dim (int, optional): The dimension along which to compute the mean. Default is 1.
Returns:
torch.Tensor: The masked mean tensor.
"""
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean