diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index fc3a4930c..04093e705 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -356,6 +356,12 @@ def apply_chat_template_and_mask(
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
+    # RL data format: pull out the ground-truth answer and keep only the message list for templating.
+ gt_answer = None
+ if "messages" in chat and "gt_answer" in chat:
+ gt_answer = chat["gt_answer"]
+ chat = [chat["messages"]]
+
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
@@ -389,6 +395,11 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
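+    # Tokenize the ground-truth answer, padded to a fixed 64-token length so it can be batched with the other fields.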
+ if gt_answer is not None:
+ gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
+ gt_answer = gt_answer.squeeze(1)
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
new file mode 100644
index 000000000..79128b89e
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -0,0 +1,150 @@
+from contextlib import nullcontext
+from typing import Optional
+
+import ray
+import torch
+from coati.distributed.consumer import BaseConsumer
+from coati.distributed.loss import PolicyLoss
+from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+@ray.remote
+class GRPOConsumer(BaseConsumer):
+ def __init__(
+ self,
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size=1,
+ ):
+ super().__init__(
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size,
+ )
+ path = model_config.pop("path")
+ self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+ self.policy_model.train()
+ self.policy_model.gradient_checkpointing_enable()
+ self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
+ self.accum_loss = torch.zeros(1, device=self.device)
+
+        # Reference model starts from the same checkpoint as the policy model and is kept frozen.
+ self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+ self.reference_model.eval()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
+ self.pad_token_id = self.tokenizer.pad_token_id
+
+ # Initialize verifiable reward.
+ response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+ }
+ self.reward_model = VerifiableReward(
+ reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+ )
+
+ self.policy_loss_fn = PolicyLoss()
+
+ def setup(self):
+ super().setup()
+ self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
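+        # The reference model is boosted without an optimizer: it stays frozen and only provides the KL baseline.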
+ self.reference_model, *_ = self.booster.boost(self.reference_model)
+
+ def step(self, step_idx: int, **kwargs) -> Optional[float]:
+ """
+ Step data from policy model:
+ [{
+ "input_ids": torch.Tensor,
+ "attention_mask": torch.Tensor,
+ "action_mask": torch.Tensor,
+ "action_log_probs": torch.Tensor,
+ },
+ ...]
+ Format:
+        [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
+ """
+ labels = kwargs["input_ids"].clone()
+ labels[kwargs["attention_mask"] == 0] = -100
+ kwargs["labels"] = labels
+ sequences = kwargs["input_ids"]
+ action_mask = kwargs["action_mask"]
+ num_action = action_mask.shape[1]
+ old_action_log_probs = kwargs["action_log_probs"]
+        assert action_mask.shape == old_action_log_probs.shape
+        # Drop the rollout-only tensors from kwargs now that local references have been taken.
+        kwargs.pop("action_mask")
+        kwargs.pop("action_log_probs")
+
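+        # Gradient accumulation: gradients are synchronized and applied only every num_microbatches steps.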
+ need_update = (step_idx + 1) % self.num_microbatches == 0
+
+ ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+ with ctx:
+ policy_model_logits = self.policy_model(
+ input_ids=kwargs["input_ids"],
+ attention_mask=kwargs["attention_mask"],
+ )["logits"]
+ action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
+
+            # The reference model only provides a baseline; no gradients are needed for it.
+            with torch.no_grad():
+                reference_model_logits = self.reference_model(
+                    input_ids=sequences,
+                    attention_mask=kwargs["attention_mask"],
+                )["logits"]
+            reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
+
+            # KL regularization: mean per-token log-prob ratio to the reference model, scaled by -0.01 and folded into the reward.
+ kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
+ action_mask, dim=-1
+ )
+
+ reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
+ reward = reward + kl
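+            # GRPO advantage: normalize the KL-adjusted rewards by their mean and std over the batch.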
+ mean = reward.view(-1, reward.size(0)).mean(dim=1)
+ std = reward.view(-1, reward.size(0)).std(dim=1)
+ advantages = (reward - mean) / (std + 1e-4)
+            # Clipped surrogate policy loss on the normalized advantages.
+ loss, skip_update, _ = self.policy_loss_fn(
+ action_log_probs,
+ old_action_log_probs,
+ advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
+ action_mask,
+ )
+
+ loss = loss / self.num_microbatches
+ self.accum_loss.add_(loss.data)
+ if not skip_update:
+ self.booster.backward(loss, self.optimizer)
+ if need_update:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ loss_scalar = self.accum_loss.item()
+ self.accum_loss.zero_()
+ return loss_scalar
+
+ def state_dict(self):
+ self.policy_model._force_wait_all_gather()
+ model = self.policy_model.unwrap()
+ state_dict = model.state_dict()
+ return state_dict
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 95b7d1e80..210ed5036 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -210,6 +210,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"action_log_probs": log_probs,
"action_mask": action_mask,
}
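+        # Forward the dataset's ground-truth answer so the consumer can compute the verifiable reward.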
+ if "gt_answer" in kwargs:
+ data["gt_answer"] = kwargs["gt_answer"]
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 438c46300..5244cc7d9 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
import ray
-from .consumer import SimpleConsumer
+from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
@@ -68,7 +68,7 @@ def launch_distributed(
)
procs.append(producer)
for i in range(num_consumer_procs):
- consumer = SimpleConsumer.options(num_gpus=1).remote(
+ consumer = GRPOConsumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
new file mode 100644
index 000000000..c08acba51
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -0,0 +1,44 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from coati.distributed.utils import masked_mean
+
+
+class PolicyLoss(nn.Module):
+ """
+ Policy Loss for PPO
+ """
+
+ def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
+ super().__init__()
+ self.clip_eps = clip_eps
+ self.skip_threshold = skip_threshold
+
+ def forward(
+ self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, bool, torch.Tensor]:
+ skip = False
+ if action_mask is None:
+ ratio_ = (log_probs - old_log_probs).exp()
+ else:
+ ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+
+        # Note: if dropout is disabled (recommended), the recomputed log probs match the rollout log probs, so the ratio stays near 1.
+ if ratio_.mean() > self.skip_threshold:
+ skip = True
+
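+        # Clamp the raw ratio for numerical stability, then apply the PPO clipped surrogate objective below.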
+ ratio = ratio_.clamp(0.0, 10.0)
+ surr1 = ratio * advantages
+ surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+ loss = -torch.min(surr1, surr2)
+ if action_mask is not None:
+ loss = masked_mean(loss, action_mask)
+ else:
+ loss = loss.mean(dim=1)
+ loss = loss.mean()
+ return loss, skip, ratio_.max()
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index f127a1ece..c7b452c54 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -3,17 +3,13 @@ import torch
from .reward_utils import extract_solution, validate_response_structure
-def math_reward_fn(input_ids, **kwargs):
- # apply varifiable reward
- # reward 10 points if the final answer is correct, reward 1 point if format is correct
-
- gt_answer = kwargs["gt_answer"]
+def math_reward_fn(input_ids, gt_answer, **kwargs):
tokenizer = kwargs["tokenizer"]
- s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
- decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+ decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
+ gt_answer = tokenizer.decode(gt_answer.squeeze(0))
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
index d1700d86f..fe889a7f4 100644
--- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
+++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
@@ -8,33 +8,27 @@ import torch
class VerifiableReward:
- def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]):
- self.reward_fn = reward_fn
- self.reward_args = reward_args
+    def __init__(self, reward_fns: List[callable], **kwargs: Any):
+ self.reward_fns = reward_fns
+ self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
- attention_mask: torch.LongTensor,
- response_start: List[int] = None,
- response_end: List[int] = None,
- gt_answer: List[str] = None,
+ gt_answer: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
# Initialize reward
- reward = torch.zeros(bs, device=input_ids.device)
+ rewards = torch.zeros(bs, device=input_ids.device)
# Loop through reward functions
- for reward_fn in self.reward_fn_list:
+ for reward_fn in self.reward_fns:
# Apply the reward function to the entire batch at once
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
- attention_mask[i],
- response_start=response_start[i],
- response_end=response_end[i],
gt_answer=gt_answer[i],
**self.kwargs,
)
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 533a5ffb2..98b54815b 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -64,3 +64,38 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
log_probs = torch.log_softmax(logits, dim=-1)
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return per_label_logps.squeeze(-1)
+
+
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+ """Calculate action log probs.
+
+ Args:
+        logits (torch.Tensor): Logits tensor from the model forward pass, shape [batch_size, seq_len, vocab_size].
+ sequences (torch.LongTensor): Input sequences.
+ num_actions (int): Number of actions.
+
+ Returns:
+ torch.Tensor: Action log probs.
+ """
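+    # Shift by one position: logits at step t predict the token at step t + 1.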
+ log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+ return log_probs[:, -num_actions:]
+
+
+def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+ """
+ Compute the masked mean of a tensor along a specified dimension.
+
+ Args:
+ tensor (torch.Tensor): The input tensor.
+ mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+ dim (int, optional): The dimension along which to compute the mean. Default is 1.
+
+ Returns:
+ torch.Tensor: The masked mean tensor.
+
+ """
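+    # Zero out masked entries, then divide by the mask count; the epsilon guards against fully-masked rows.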
+ tensor = tensor * mask
+ tensor = tensor.sum(dim=dim)
+ mask_sum = mask.sum(dim=dim)
+ mean = tensor / (mask_sum + 1e-8)
+ return mean