From f736d747e3b3c2601a15d240db669607e8aacba9 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 25 Feb 2025 18:12:04 +0800
Subject: [PATCH] update grpo

---
 .../coati/distributed/grpo_consumer.py       | 70 +++++++++++++------
 1 file changed, 50 insertions(+), 20 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 79128b89e..d88df2360 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -3,11 +3,13 @@ from typing import Optional
 
 import ray
 import torch
+import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
+from coati.trainer.utils import all_reduce_mean, is_rank_0
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -29,6 +31,8 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
+        num_generations=4,
+        use_wandb=False,
     ):
         super().__init__(
             num_producers,
@@ -50,6 +54,8 @@
         self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4)
         self.accum_loss = torch.zeros(1, device=self.device)
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_kl = torch.zeros(1, device=self.device)
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -57,6 +63,7 @@
 
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
 
         # Initialize verifiable reward.
         response_format_tags = {
@@ -70,6 +77,8 @@
         )
 
         self.policy_loss_fn = PolicyLoss()
+        if is_rank_0():
+            self.run = wandb.init(project="Colossal-GRPO-Test4")
 
     def setup(self):
         super().setup()
@@ -87,43 +96,52 @@
             }, ...]
         Format:
-            [batch_size, prompt_length + response_length] --- .............
+            [batch_size, num_of_generation, prompt_length + response_length] --- .............
         """
-        labels = kwargs["input_ids"].clone()
-        labels[kwargs["attention_mask"] == 0] = -100
-        kwargs["labels"] = labels
-        sequences = kwargs["input_ids"]
-        action_mask = kwargs["action_mask"]
+
+        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
-        old_action_log_probs = kwargs["action_log_probs"]
-        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
+        old_action_log_probs = data["action_log_probs"]
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
 
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
-                input_ids=kwargs["input_ids"],
-                attention_mask=kwargs["attention_mask"],
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
             )["logits"]
-            action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action)
+            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
 
             reference_model_logits = self.reference_model(
-                input_ids=sequences,
-                attention_mask=kwargs["attention_mask"],
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
             )["logits"]
-            reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action)
+            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+
+            # GRPO advantage calculation
+            kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
+                action_mask, dim=-1
+            )
+
+            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = kl + reward
+            # [batch_size, num_generations]
+            group_reward = reward.view(-1, self.num_generations)
+
+            # [batch_size x num_generations]
+            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4)
 
             # GRPO advantage calculation
             kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum(
                 action_mask, dim=-1
             )
-            reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"])
-            reward = reward + kl
-            mean = reward.view(-1, reward.size(0)).mean(dim=1)
-            std = reward.view(-1, reward.size(0)).std(dim=1)
-            advantages = (reward - mean) / (std + 1e-4)
 
             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
@@ -133,14 +151,26 @@
             )
 
             loss = loss / self.num_microbatches
-            self.accum_loss.add_(loss.data)
             if not skip_update:
                 self.booster.backward(loss, self.optimizer)
+            loss = all_reduce_mean(loss)
+            reward = all_reduce_mean(reward.mean())
+            kl = all_reduce_mean(kl.mean())
+            self.accum_loss.add_(loss.data)
+            self.accum_reward.add_(reward.data)
+            self.accum_kl.add_(kl.data)
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
+            if is_rank_0():
+                print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
+                self.run.log(
+                    {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                )
             self.accum_loss.zero_()
+            self.accum_reward.zero_()
+            self.accum_kl.zero_()
             return loss_scalar
 
     def state_dict(self):
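
Note: the group-relative advantage step introduced in the step() hunk above can
be exercised in isolation. The sketch below is illustrative only; the function
name grpo_advantages and the standalone setup are not part of the repository.
It assumes one scalar reward per completion and num_generations completions
per prompt, mirroring the reshape / mean / std / normalize sequence in the diff:

    import torch

    def grpo_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
        # reward: [batch_size * num_generations], one scalar per sampled completion.
        group = reward.view(-1, num_generations)                     # [batch_size, num_generations]
        mean = group.mean(dim=1).repeat_interleave(num_generations)  # per-group mean, broadcast back
        std = group.std(dim=1).repeat_interleave(num_generations)    # per-group std, broadcast back
        return (reward - mean) / (std + eps)                         # z-score within each prompt group

    # Two prompts, four generations each: the first group normalizes to roughly
    # +/-0.87; the second is constant, so std is 0 and eps keeps the result finite.
    rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0])
    print(grpo_advantages(rewards, num_generations=4))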
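
The sequence-level KL term that the same hunk folds into the reward reads the
same way on its own: the per-token log-ratio against the frozen reference model
is restricted to response tokens by action_mask, averaged per sequence, and
scaled by a coefficient (0.1 in the added block, 0.01 in the pre-existing one).
A minimal sketch under those assumed shapes, with illustrative names only:

    import torch

    def sequence_kl_penalty(logp: torch.Tensor, ref_logp: torch.Tensor,
                            action_mask: torch.Tensor, coef: float = 0.1) -> torch.Tensor:
        # logp / ref_logp: [batch, num_actions] per-token log-probs from the
        # policy and the frozen reference model; action_mask marks real tokens.
        per_token = -coef * (logp - ref_logp)  # penalize drift away from the reference
        return (per_token * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)

    # The consumer adds this scalar to the verifiable reward before the group
    # normalization step: reward = kl + reward.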