mirror of https://github.com/hpcaitech/ColossalAI
add simple grpo
parent 8e6c9a4ab3
commit ffd3878a1e
@@ -356,6 +356,12 @@ def apply_chat_template_and_mask(
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+    # Format for RL.
+    gt_answer = None
+    if "messages" in chat and "gt_answer" in chat:
+        gt_answer = chat["gt_answer"]
+        chat = [chat["messages"]]
+
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
@@ -389,6 +395,11 @@ def apply_chat_template_and_mask(
     labels = input_ids.clone()
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
 
+    if gt_answer is not None:
+        gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
+        gt_answer = gt_answer.squeeze(1)
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
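For orientation, here is a minimal sketch (not part of the diff) of the RL-formatted sample the new branch expects; the question text and message structure are purely illustrative:

    # Hypothetical RL sample: the conversation sits under "messages" and the
    # verifiable target under "gt_answer", as checked by the branch above.
    rl_sample = {
        "messages": [
            {"role": "user", "content": "Compute 3 * 7 and put the result in <answer></answer> tags."},
        ],
        "gt_answer": "21",
    }
    # For such a sample, apply_chat_template_and_mask additionally returns a
    # "gt_answer" tensor, tokenized and padded to a fixed 64 tokens, alongside
    # the usual "input_ids", "attention_mask" and "labels".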
@@ -0,0 +1,150 @@
from contextlib import nullcontext
from typing import Optional

import ray
import torch
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.optimizer import HybridAdam


@ray.remote
class GRPOConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
        self.accum_loss = torch.zeros(1, device=self.device)

        # Reference model is initialized from the policy model.
        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id

        # Initialize verifiable reward.
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        self.reward_model = VerifiableReward(
            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
        )

        self.policy_loss_fn = PolicyLoss()

    def setup(self):
        super().setup()
        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
        self.reference_model, *_ = self.booster.boost(self.reference_model)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
        """
        labels = kwargs["input_ids"].clone()
        labels[kwargs["attention_mask"] == 0] = -100
        kwargs["labels"] = labels
        sequences = kwargs["input_ids"]
        action_mask = kwargs["action_mask"]
        num_action = action_mask.shape[1]
        old_action_log_probs = kwargs["action_log_probs"]
        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape

        need_update = (step_idx + 1) % self.num_microbatches == 0

        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
        with ctx:
            policy_model_logits = self.policy_model(
                input_ids=kwargs["input_ids"],
                attention_mask=kwargs["attention_mask"],
            )["logits"]
            action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)

            reference_model_logits = self.reference_model(
                input_ids=sequences,
                attention_mask=kwargs["attention_mask"],
            )["logits"]
            reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)

            # GRPO advantage calculation
            kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
                action_mask, dim=-1
            )

            reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
            reward = reward + kl
            mean = reward.view(-1, reward.size(0)).mean(dim=1)
            std = reward.view(-1, reward.size(0)).std(dim=1)
            advantages = (reward - mean) / (std + 1e-4)
            # Calculate loss
            loss, skip_update, _ = self.policy_loss_fn(
                action_log_probs,
                old_action_log_probs,
                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                action_mask,
            )

            loss = loss / self.num_microbatches
            self.accum_loss.add_(loss.data)
            if not skip_update:
                self.booster.backward(loss, self.optimizer)
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            self.accum_loss.zero_()
            return loss_scalar

    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict
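As a standalone illustration (not part of the commit) of the advantage computation in step() above, using dummy tensors; the 0.01 KL coefficient and the 1e-4 epsilon come from the code, everything else is made up:

    import torch

    # Dummy batch: 4 sampled responses, 8 response tokens each.
    action_log_probs = torch.randn(4, 8)
    reference_action_log_probs = torch.randn(4, 8)
    action_mask = torch.ones(4, 8)
    verifiable_reward = torch.tensor([1.0, 11.0, 0.0, 11.0])  # e.g. format + correctness scores

    # Per-sequence KL penalty, averaged over response tokens (coefficient 0.01 as above).
    kl = torch.sum(
        -0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1
    ) / torch.sum(action_mask, dim=-1)

    # Group-relative advantage: normalize the penalized reward by its mean and std.
    # (For a 1-D reward this matches the view(-1, bs).mean/std computation in step().)
    reward = verifiable_reward + kl
    advantages = (reward - reward.mean()) / (reward.std() + 1e-4)
    print(advantages)  # one scalar advantage per sampled response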
@@ -210,6 +210,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "action_log_probs": log_probs,
             "action_mask": action_mask,
         }
+        if "gt_answer" in kwargs:
+            data["gt_answer"] = kwargs["gt_answer"]
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
 
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
 
 import ray
 
-from .consumer import SimpleConsumer
+from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
 
 
@@ -68,7 +68,7 @@ def launch_distributed(
         )
         procs.append(producer)
     for i in range(num_consumer_procs):
-        consumer = SimpleConsumer.options(num_gpus=1).remote(
+        consumer = GRPOConsumer.options(num_gpus=1).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
@@ -0,0 +1,44 @@
from typing import Optional

import torch
import torch.nn as nn
from coati.distributed.utils import masked_mean


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
        super().__init__()
        self.clip_eps = clip_eps
        self.skip_threshold = skip_threshold

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        skip = False
        if action_mask is None:
            ratio_ = (log_probs - old_log_probs).exp()
        else:
            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()

        # Note that if dropout is disabled (recommended), the ratio will always be 1.
        if ratio_.mean() > self.skip_threshold:
            skip = True

        ratio = ratio_.clamp(0.0, 10.0)
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        else:
            loss = loss.mean(dim=1)
        loss = loss.mean()
        return loss, skip, ratio_.max()
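A quick usage sketch of PolicyLoss with dummy tensors (shapes and values are illustrative; the import path follows the one used in grpo_consumer above):

    import torch

    from coati.distributed.loss import PolicyLoss

    # Two sequences with 5 response tokens each; near on-policy log probs.
    log_probs = torch.randn(2, 5)
    old_log_probs = log_probs + 0.05 * torch.randn(2, 5)
    advantages = torch.randn(2, 1).repeat_interleave(5, dim=-1)  # one advantage per sequence, broadcast over tokens
    action_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]], dtype=torch.float)

    loss_fn = PolicyLoss(clip_eps=0.2, skip_threshold=20.0)
    loss, skip, max_ratio = loss_fn(log_probs, old_log_probs, advantages, action_mask)
    # `skip` flags batches whose mean probability ratio exceeds skip_threshold;
    # GRPOConsumer.step() then skips the backward pass for that microbatch.
    print(loss.item(), skip, max_ratio.item())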
@@ -3,17 +3,13 @@ import torch
 from .reward_utils import extract_solution, validate_response_structure
 
 
-def math_reward_fn(input_ids, **kwargs):
-    # apply varifiable reward
-
-    gt_answer = kwargs["gt_answer"]
+def math_reward_fn(input_ids, gt_answer, **kwargs):
+    # reward 10 points if the final answer is correct, reward 1 point if format is correct
     tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
     reward = torch.tensor(0.0).to(input_ids.device)
     if gt_answer is None:
         return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0))
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
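The rest of math_reward_fn lies outside this hunk. Going only by the new comment (10 points for a correct final answer, 1 point for valid format), the continuation could plausibly look like the sketch below; the exact scoring in the repository may differ:

    # Hypothetical continuation, not shown in the hunk above.
    if format_valid:
        reward += 1.0
    if final_answer is not None and gt_answer.strip() == final_answer.strip():
        reward += 10.0
    return reward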
@@ -8,33 +8,27 @@ import torch
 
 
 class VerifiableReward:
-    def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]):
-        self.reward_fn = reward_fn
-        self.reward_args = reward_args
+    def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
+        self.reward_fns = reward_fns
+        self.kwargs = kwargs
 
     def __call__(
         self,
         input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        response_start: List[int] = None,
-        response_end: List[int] = None,
-        gt_answer: List[str] = None,
+        gt_answer: List[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
         # Initialize reward
-        reward = torch.zeros(bs, device=input_ids.device)
+        rewards = torch.zeros(bs, device=input_ids.device)
 
         # Loop through reward functions
-        for reward_fn in self.reward_fn_list:
+        for reward_fn in self.reward_fns:
             # Apply the reward function to the entire batch at once
             reward_batch = torch.stack(
                 [
                     reward_fn(
                         input_ids[i],
-                        attention_mask[i],
-                        response_start=response_start[i],
-                        response_end=response_end[i],
                         gt_answer=gt_answer[i],
                         **self.kwargs,
                     )
@@ -64,3 +64,38 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     log_probs = torch.log_softmax(logits, dim=-1)
     per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
     return per_label_logps.squeeze(-1)
+
+
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+    """Calculate action log probs.
+
+    Args:
+        logits (torch.Tensor): Output logits of the actor's forward pass.
+        sequences (torch.LongTensor): Input sequences.
+        num_actions (int): Number of actions.
+
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    return log_probs[:, -num_actions:]
+
+
+def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+    """
+    Compute the masked mean of a tensor along a specified dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor.
+        mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+        dim (int, optional): The dimension along which to compute the mean. Default is 1.
+
+    Returns:
+        torch.Tensor: The masked mean tensor.
+    """
+    tensor = tensor * mask
+    tensor = tensor.sum(dim=dim)
+    mask_sum = mask.sum(dim=dim)
+    mean = tensor / (mask_sum + 1e-8)
+    return mean
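A small shape-check sketch (not part of the commit) for the two helpers above; sizes are arbitrary and the import path follows the one used in grpo_consumer and loss:

    import torch

    from coati.distributed.utils import calc_action_log_probs, masked_mean

    batch, seq_len, vocab, num_actions = 2, 10, 32, 4
    logits = torch.randn(batch, seq_len, vocab)             # model logits for prompt + response
    sequences = torch.randint(0, vocab, (batch, seq_len))   # corresponding token ids

    # Log prob of each of the last `num_actions` (response) tokens under the model.
    action_log_probs = calc_action_log_probs(logits, sequences, num_actions)
    print(action_log_probs.shape)  # torch.Size([2, 4])

    # Average over the non-padded response positions only.
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)
    print(masked_mean(action_log_probs, mask))  # one value per sequence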