diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index fc3a4930c..04093e705 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -356,6 +356,12 @@ def apply_chat_template_and_mask(
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+    # Format for RL.
+    gt_answer = None
+    if "messages" in chat and "gt_answer" in chat:
+        gt_answer = chat["gt_answer"]
+        chat = [chat["messages"]]
+
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
@@ -389,6 +395,11 @@
     labels = input_ids.clone()
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
 
+    if gt_answer is not None:
+        gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
+        gt_answer = gt_answer.squeeze(1)
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
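
For context, a standalone sketch (not part of the patch) of what the gt_answer encoding above produces per sample; the checkpoint name is a placeholder, and any tokenizer with a pad token behaves the same way:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

    gt = tokenizer.encode("42", padding="max_length", max_length=64, return_tensors="pt")
    print(gt.shape)  # torch.Size([1, 64]); stacking per-sample tensors in a batch gives (B, 1, 64)
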
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
new file mode 100644
index 000000000..79128b89e
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -0,0 +1,150 @@
+from contextlib import nullcontext
+from typing import Optional
+
+import ray
+import torch
+from coati.distributed.consumer import BaseConsumer
+from coati.distributed.loss import PolicyLoss
+from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+@ray.remote
+class GRPOConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.policy_model.gradient_checkpointing_enable()
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
+        self.accum_loss = torch.zeros(1, device=self.device)
+
+        # Reference model is initialized from policy model.
+        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.reference_model.eval()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.policy_loss_fn = PolicyLoss()
+
+    def setup(self):
+        super().setup()
+        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.reference_model, *_ = self.booster.boost(self.reference_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        """
+        Step data from policy model:
+            [{
+                "input_ids": torch.Tensor,
+                "attention_mask": torch.Tensor,
+                "action_mask": torch.Tensor,
+                "action_log_probs": torch.Tensor,
+            },
+            ...]
+        Format:
+            [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
+        """
+        labels = kwargs["input_ids"].clone()
+        labels[kwargs["attention_mask"] == 0] = -100
+        kwargs["labels"] = labels
+        sequences = kwargs["input_ids"]
+        action_mask = kwargs["action_mask"]
+        num_action = action_mask.shape[1]
+        old_action_log_probs = kwargs["action_log_probs"]
+        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
+
+        need_update = (step_idx + 1) % self.num_microbatches == 0
+
+        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        with ctx:
+            policy_model_logits = self.policy_model(
+                input_ids=kwargs["input_ids"],
+                attention_mask=kwargs["attention_mask"],
+            )["logits"]
+            action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
+
+            reference_model_logits = self.reference_model(
+                input_ids=sequences,
+                attention_mask=kwargs["attention_mask"],
+            )["logits"]
+            reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
+
+            # GRPO advantage calculation
+            kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
+                action_mask, dim=-1
+            )
+
+            reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
+            reward = reward + kl
+            mean = reward.view(-1, reward.size(0)).mean(dim=1)
+            std = reward.view(-1, reward.size(0)).std(dim=1)
+            advantages = (reward - mean) / (std + 1e-4)
+            # Calculate Loss
+            loss, skip_update, _ = self.policy_loss_fn(
+                action_log_probs,
+                old_action_log_probs,
+                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
+                action_mask,
+            )
+
+            loss = loss / self.num_microbatches
+            self.accum_loss.add_(loss.data)
+            if not skip_update:
+                self.booster.backward(loss, self.optimizer)
+        if need_update:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            loss_scalar = self.accum_loss.item()
+            self.accum_loss.zero_()
+            return loss_scalar
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 95b7d1e80..210ed5036 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -210,6 +210,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "action_log_probs": log_probs,
             "action_mask": action_mask,
         }
+        if "gt_answer" in kwargs:
+            data["gt_answer"] = kwargs["gt_answer"]
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
 
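
For intuition, a self-contained sketch (not part of the patch) of the group-relative advantage idea used in step() above: rewards within a group of rollouts are normalized by the group mean and standard deviation, so better-than-average responses receive positive advantages. The helper name and the single-group batch are illustrative only.

    import torch

    def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        # rewards: (G,) scalar rewards for G rollouts of the same prompt.
        return (rewards - rewards.mean()) / (rewards.std() + eps)

    rewards = torch.tensor([10.0, 1.0, 1.0, 0.0])  # one correct answer, three wrong or format-only
    print(group_relative_advantages(rewards))      # the correct rollout gets a large positive advantage
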
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 438c46300..5244cc7d9 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
 
 import ray
 
-from .consumer import SimpleConsumer
+from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
 
 
@@ -68,7 +68,7 @@ def launch_distributed(
         )
         procs.append(producer)
     for i in range(num_consumer_procs):
-        consumer = SimpleConsumer.options(num_gpus=1).remote(
+        consumer = GRPOConsumer.options(num_gpus=1).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
new file mode 100644
index 000000000..c08acba51
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -0,0 +1,44 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from coati.distributed.utils import masked_mean
+
+
+class PolicyLoss(nn.Module):
+    """
+    Policy Loss for PPO
+    """
+
+    def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None:
+        super().__init__()
+        self.clip_eps = clip_eps
+        self.skip_threshold = skip_threshold
+
+    def forward(
+        self,
+        log_probs: torch.Tensor,
+        old_log_probs: torch.Tensor,
+        advantages: torch.Tensor,
+        action_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        skip = False
+        if action_mask is None:
+            ratio_ = (log_probs - old_log_probs).exp()
+        else:
+            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+
+        # Note that if dropout is disabled (recommended), the ratio will always be 1.
+        if ratio_.mean() > self.skip_threshold:
+            skip = True
+
+        ratio = ratio_.clamp(0.0, 10.0)
+        surr1 = ratio * advantages
+        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+        loss = -torch.min(surr1, surr2)
+        if action_mask is not None:
+            loss = masked_mean(loss, action_mask)
+        else:
+            loss = loss.mean(dim=1)
+        loss = loss.mean()
+        return loss, skip, ratio_.max()
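
A shape-level usage sketch for PolicyLoss (not part of the patch; the tensors are random and B, T are arbitrary), mirroring how the consumer calls it with per-token log probs, per-sequence advantages broadcast over the token dimension, and an action mask:

    import torch
    from coati.distributed.loss import PolicyLoss

    loss_fn = PolicyLoss(clip_eps=0.2, skip_threshold=20.0)
    B, T = 2, 8                                           # batch size, number of response tokens
    log_probs = torch.randn(B, T)                         # current policy log probs per action token
    old_log_probs = log_probs + 0.01 * torch.randn(B, T)  # rollout-time log probs
    advantages = torch.randn(B, 1).repeat_interleave(T, dim=-1)
    action_mask = torch.ones(B, T)

    loss, skip, max_ratio = loss_fn(log_probs, old_log_probs, advantages, action_mask)
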
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index f127a1ece..c7b452c54 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -3,17 +3,13 @@ import torch
 from .reward_utils import extract_solution, validate_response_structure
 
 
-def math_reward_fn(input_ids, **kwargs):
-    # apply varifiable reward
-    # reward 10 points if the final answer is correct, reward 1 point if format is correct
-
-    gt_answer = kwargs["gt_answer"]
+def math_reward_fn(input_ids, gt_answer, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
     reward = torch.tensor(0.0).to(input_ids.device)
     if gt_answer is None:
         return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0))
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
index d1700d86f..fe889a7f4 100644
--- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
+++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
@@ -8,33 +8,27 @@ import torch
 
 
 class VerifiableReward:
-    def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]):
-        self.reward_fn = reward_fn
-        self.reward_args = reward_args
+    def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
+        self.reward_fns = reward_fns
+        self.kwargs = kwargs
 
     def __call__(
         self,
         input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        response_start: List[int] = None,
-        response_end: List[int] = None,
-        gt_answer: List[str] = None,
+        gt_answer: List[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
         # Initialize reward
-        reward = torch.zeros(bs, device=input_ids.device)
+        rewards = torch.zeros(bs, device=input_ids.device)
 
         # Loop through reward functions
-        for reward_fn in self.reward_fn_list:
+        for reward_fn in self.reward_fns:
             # Apply the reward function to the entire batch at once
             reward_batch = torch.stack(
                 [
                     reward_fn(
                         input_ids[i],
-                        attention_mask[i],
-                        response_start=response_start[i],
-                        response_end=response_end[i],
                         gt_answer=gt_answer[i],
                         **self.kwargs,
                     )
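
A hedged sketch of the new VerifiableReward call shape (not part of the patch): tokenizer and tags are bound once at construction, and the call takes only the generated sequences plus the encoded ground-truth answers. The checkpoint name, batch size, and the (B, 1, 64) gt_answer shape (the loader's (1, 64) per-sample tensor after batching) are assumptions for illustration.

    import torch
    from transformers import AutoTokenizer
    from coati.distributed.reward.reward_fn import math_reward_fn
    from coati.distributed.reward.verifiable_reward import VerifiableReward

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    tags = {
        "think_start": {"text": "<think>", "num_occur": 1},
        "think_end": {"text": "</think>", "num_occur": 1},
        "answer_start": {"text": "<answer>", "num_occur": 1},
        "answer_end": {"text": "</answer>", "num_occur": 1},
    }
    reward_model = VerifiableReward(reward_fns=[math_reward_fn], tokenizer=tokenizer, tags=tags)

    sequences = torch.randint(0, tokenizer.vocab_size, (4, 128))    # (B, prompt + response) token ids
    gt_answer = torch.randint(0, tokenizer.vocab_size, (4, 1, 64))  # assumed (B, 1, 64) padded answer ids
    rewards = reward_model(sequences, gt_answer=gt_answer)          # one scalar reward per sequence
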
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 533a5ffb2..98b54815b 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -64,3 +64,38 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     log_probs = torch.log_softmax(logits, dim=-1)
     per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
     return per_label_logps.squeeze(-1)
+
+
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+    """Calculate action log probs.
+
+    Args:
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
+        sequences (torch.LongTensor): Input sequences.
+        num_actions (int): Number of actions.
+
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    return log_probs[:, -num_actions:]
+
+
+def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+    """
+    Compute the masked mean of a tensor along a specified dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor.
+        mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+        dim (int, optional): The dimension along which to compute the mean. Default is 1.
+
+    Returns:
+        torch.Tensor: The masked mean tensor.
+
+    """
+    tensor = tensor * mask
+    tensor = tensor.sum(dim=dim)
+    mask_sum = mask.sum(dim=dim)
+    mean = tensor / (mask_sum + 1e-8)
+    return mean
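
A small self-contained check (not part of the patch) of the two helpers added above: calc_action_log_probs keeps only the log probs of the last num_actions (response) tokens, and masked_mean averages only over unmasked positions. The shapes are arbitrary.

    import torch
    from coati.distributed.utils import calc_action_log_probs, masked_mean

    B, L, V, num_actions = 2, 10, 32, 4
    logits = torch.randn(B, L, V)
    sequences = torch.randint(0, V, (B, L))

    action_log_probs = calc_action_log_probs(logits, sequences, num_actions)
    print(action_log_probs.shape)  # torch.Size([2, 4]) -- one log prob per response token

    values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
    print(masked_mean(values, mask))  # approximately tensor([1.5000, 5.0000])
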