diff --git a/.gitignore b/.gitignore
index 8bc74b4c8..af7d5c3fa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,4 +162,5 @@ coverage.xml
 # log, test files - ColossalChat
 applications/ColossalChat/logs
+applications/ColossalChat/wandb
 applications/ColossalChat/tests/logs
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 1e85cccb3..646853829 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -14,6 +14,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch
@@ -76,6 +77,10 @@ class BaseConsumer:
             plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
+        if hasattr(self, "critic_model"):
+            plugin_config.update({"custom_policy": get_autopolicy(self.critic_model.model)})
+            self.critic_plugin = HybridParallelPlugin(**plugin_config)
+            self.critic_booster = Booster(plugin=self.critic_plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index bc0ae5c36..1ca40bf65 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -60,6 +60,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
+        self.generate_config["tokenizer"] = tokenizer
         self.tokenizer = tokenizer
 
     @torch.no_grad()
@@ -76,21 +77,26 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
 
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
             attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
 
         attention_mask = torch.cat((attention_mask, action_mask), dim=1)
+
         data = {
             "input_ids": out.sequences,
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
 
         return data
@@ -154,7 +160,6 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=4,
     )
 
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
@@ -167,7 +172,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = generate_config["n"]
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
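The new `response_idx` field records, for every generated sample, the first and last (inclusive) token positions of the response inside the padded sequence, so downstream consumers can slice the answer back out without re-deriving it from the attention mask. A minimal, self-contained sketch of that convention (the tensors below are made-up stand-ins, not values produced by the backend):

```python
import torch

# Hypothetical example: a 4-token prompt followed by a 3-token response and one padded position.
input_len = 4
sequence = torch.tensor([[11, 12, 13, 14, 21, 22, 23, 0]])
action_mask = torch.tensor([[1, 1, 1, 0]])  # positions after EOS are masked out

response_idx = torch.zeros((sequence.size(0), 2), dtype=torch.int)
response_idx[:, 0] = input_len                               # first response token
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1  # last response token (inclusive)

start, end = response_idx[0].tolist()
print(response_idx, sequence[0, start : end + 1])  # -> [[4, 6]] and tensor([21, 22, 23])
```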
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 8581ff586..bb492a047 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -4,11 +4,13 @@
 import ray
 
 from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
+from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer
 
 ALGO_MAP = {
     "Simple": SimpleConsumer,
     "GRPO": GRPOConsumer,
+    "PPO": PPOConsumer,
 }
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index 222540b92..e63209f34 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -45,3 +45,31 @@ class PolicyLoss(nn.Module):
         loss = loss.mean(dim=1)
         loss = loss.mean()
         return loss, skip, ratio_.max()
+
+
+class ValueLoss(nn.Module):
+    """
+    Value Loss for PPO
+    """
+
+    def __init__(self, clip_eps: float = 0.2) -> None:
+        super().__init__()
+        self.clip_eps = clip_eps
+
+    def forward(
+        self,
+        values: torch.Tensor,
+        old_values: torch.Tensor,
+        advantage: torch.Tensor,
+        action_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        returns = advantage + old_values
+        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
+        surr1 = (values_clipped - returns) ** 2
+        surr2 = (values - returns) ** 2
+        if action_mask is not None:
+            # loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+            loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
+        else:
+            loss = torch.mean(torch.max(surr1, surr2))
+        return 0.5 * loss
\ No newline at end of file
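The new `ValueLoss` clips the critic update in the same spirit as PPO clips the policy ratio: the value prediction is regressed towards `returns = advantage + old_values`, but it may not move more than `clip_eps` away from the old value, and the larger of the clipped and unclipped squared errors is used. A self-contained sketch of that computation for the unmasked case (`clipped_value_loss` is illustrative and not part of the patch):

```python
import torch


def clipped_value_loss(values, old_values, advantage, clip_eps=0.2):
    # Regress towards returns = advantage + old_values, but never let the
    # prediction move more than clip_eps away from the old value estimate.
    returns = advantage + old_values
    values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
    surr1 = (values_clipped - returns) ** 2
    surr2 = (values - returns) ** 2
    return 0.5 * torch.mean(torch.max(surr1, surr2))


old_values = torch.tensor([1.0, 1.0])
advantage = torch.tensor([0.5, -0.5])
values = torch.tensor([2.0, 0.9])
print(clipped_value_loss(values, old_values, advantage))  # tensor(0.1025)
```

Note that `ppo_consumer.py` currently passes `value.detach()` as the old value (the code itself flags this as a hack), so the clipped and unclipped branches coincide until old values are cached from the rollout.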
diff --git a/applications/ColossalChat/coati/distributed/ppo_consumer.py b/applications/ColossalChat/coati/distributed/ppo_consumer.py
new file mode 100644
index 000000000..d7d80149c
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/ppo_consumer.py
@@ -0,0 +1,262 @@
+from contextlib import nullcontext
+from typing import Optional
+
+import ray
+import torch
+import wandb
+from coati.distributed.consumer import BaseConsumer
+from coati.distributed.loss import PolicyLoss, ValueLoss
+from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
+from coati.trainer.utils import all_reduce_mean
+from coati.models import Critic, disable_dropout
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+@ray.remote
+class PPOConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=1,
+        gamma: float = 1.0,
+        lam: float = 0.95,
+        kl_coef: float = 0.05,
+        use_wandb=True,
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        self.gamma = gamma
+        self.lam = lam
+        self.kl_coef = kl_coef
+
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.policy_model.gradient_checkpointing_enable()
+        self.critic_model = Critic(path, **model_config)
+        self.critic_model.model.gradient_checkpointing_enable()
+        self.critic_model.train()
+
+        # Disable dropout
+        disable_dropout(self.policy_model)
+        disable_dropout(self.critic_model)
+
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
+
+        self.accum_loss = torch.zeros(1, device=self.device)
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_advantage = torch.zeros(1, device=self.device)
+        self.accum_critic_loss = torch.zeros(1, device=self.device)
+        self.accum_count = 0
+
+        # Reference model is initialized from the policy model.
+        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.reference_model.eval()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.policy_loss_fn = PolicyLoss()
+        self.critic_loss_fn = ValueLoss()
+        self.global_step = 0
+        if use_wandb and self.rank == 0:
+            self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)
+
+    def setup(self):
+        super().setup()
+        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
+            self.critic_model, self.critic_optimizer
+        )
+        self.reference_model, *_ = self.booster.boost(self.reference_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        """
+        Step data from policy model:
+            [{
+                "input_ids": torch.Tensor,
+                "attention_mask": torch.Tensor,
+                "action_mask": torch.Tensor,
+                "action_log_probs": torch.Tensor,
+            },
+            ...]
+        Format:
+            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
+        """
+
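+        # High-level flow of one PPO update, as implemented below:
+        #   1. Re-run the policy and the frozen reference model to get per-token log probs.
+        #   2. Query the critic for per-token values over the response segment.
+        #   3. Compute the verifiable reward and fold the KL penalty into a per-token reward.
+        #   4. Compute GAE advantages, then update the policy (clipped PPO loss) and the
+        #      critic (clipped value loss).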
+ """ + + # Reshape to [batch_size x num_of_generation, prompt_length + response_length] + data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} + action_mask = data["action_mask"] + num_action = action_mask.shape[1] + old_action_log_probs = data["action_log_probs"].detach() + + need_update = (step_idx + 1) % self.num_microbatches == 0 + + ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) + with ctx: + policy_model_logits = self.policy_model( + input_ids=data["input_ids"], + attention_mask=data["attention_mask"], + )["logits"] + action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action) + + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=data["input_ids"], + attention_mask=data["attention_mask"], + )["logits"] + reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + + value = self.critic_model( + input_ids=data["input_ids"], + attention_mask=data["attention_mask"], + ) + value = value[:, -num_action -1: -1] * action_mask + + r = self.reward_model( + data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] + ) + reward, kl = compute_reward_ppo( + r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask + ) + + # Calculate advantages + # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46lastgaelam = 0 + lastgaelam = 0 + advantage_reversed = [] + for t in reversed(range(num_action)): + nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0 + delta = reward[:, t] + self.gamma * nextvalues - value[:, t] + lastgaelam = delta + self.gamma * self.lam * lastgaelam + advantage_reversed.append(lastgaelam) + advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask + advantage = advantage.detach() + + # KL divergence for logging + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1) + + # Calculate Loss + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantage, + 0, # kl is already included in the advantage + action_mask, + ) + + # Critic Loss + # Hack: use the current value to approximate the old value, should be old value mathematically + critic_loss = self.critic_loss_fn( + value, + value.detach(), + advantage, + action_mask=action_mask, + ) + + if not skip_update: + self.booster.backward(loss, self.optimizer) + self.critic_booster.backward(critic_loss, self.critic_optimizer) + loss = all_reduce_mean(loss, self.plugin) + critic_loss = all_reduce_mean(critic_loss, self.plugin) + r_mean = all_reduce_mean(r.mean(), self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) + advantage = all_reduce_mean(advantage.mean(), self.plugin) + self.accum_loss.add_(loss.data) + self.accum_critic_loss.add_(critic_loss.data) + self.accum_advantage.add_(advantage.data) + self.accum_reward.add_(r_mean.data) + self.accum_kl.add_(kl.data) + self.accum_count += 1 + if self.rank == 0: + print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}") + if need_update: + self.optimizer.step() + self.optimizer.zero_grad() + self.critic_optimizer.step() + self.critic_optimizer.zero_grad() + loss_scalar = self.accum_loss.item() + if self.rank == 0: + print( + "Loss:", + 
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 3e4a5277a..eb0621180 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -154,6 +154,11 @@ class SimpleProducer(BaseProducer):
 
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
+        if self.backend_cls.__name__ == "TransformersInferenceBackend":
+            gt_answer = kwargs.pop("gt_answer")
+            out = self.model.generate(input_ids, attention_mask, **kwargs)
+            out["gt_answer"] = gt_answer.to(out["input_ids"].device)
+            return out
         return self.model.generate(input_ids, attention_mask, **kwargs)
 
     def load_state_dict(self, state_dict):
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 9e6d1066e..7b34c6d0b 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -19,8 +19,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         return reward
     else:
         reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 2.0
+        # if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+        #     reward = reward + 2.0
     return reward
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 98b54815b..837f48218 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -99,3 +99,30 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def compute_reward_ppo(
+    r: Union[torch.Tensor, float],
+    kl_coef: float,
+    log_probs: torch.Tensor,
+    log_probs_base: torch.Tensor,
+    action_mask: Optional[torch.Tensor] = None,
+    reward_eps=5,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        r: [batch_size], scalar reward for each sample
+        kl_coef: coefficient of the per-token KL penalty
+        log_probs: [batch_size, response_length]
+        log_probs_base: [batch_size, response_length]
+        action_mask: [batch_size, response_length]
+    Returns:
+        reward: [batch_size, response_length], per-token reward with the KL penalty folded in
+        kl_estimate: [batch_size, response_length], per-token estimate of the policy/reference KL
+    """
+    log_ratio = log_probs - log_probs_base  # address numerical instability issue
+    kl = -kl_coef * log_ratio * action_mask
+    reward = kl
+    r_clip = torch.clamp(r, -reward_eps, reward_eps)
+    for i in range(action_mask.size(0)):
+        assert action_mask[i].sum() > 0
+        reward[i, : action_mask[i].sum()] += r_clip[i]
+        reward[i, action_mask[i].sum() :] *= 0
+    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
\ No newline at end of file
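`compute_reward_ppo` turns the scalar verifiable reward into a per-token training signal: every response token receives a `-kl_coef * (log_probs - log_probs_base)` penalty, the clamped scalar reward is added over the unmasked response span, and padded positions are zeroed; a per-token KL estimate is returned alongside it. A usage sketch, assuming the patched `coati` package is importable (all tensors below are made up):

```python
import torch

from coati.distributed.utils import compute_reward_ppo  # added by this patch

r = torch.tensor([1.0, 0.0])  # scalar verifiable reward per sample
log_probs = torch.tensor([[-1.0, -1.1, -0.9], [-1.2, -1.0, -1.0]])
log_probs_base = torch.tensor([[-1.0, -1.0, -1.0], [-1.1, -1.0, -1.0]])
action_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second response has one padded position

reward, kl_estimate = compute_reward_ppo(r, 0.05, log_probs, log_probs_base, action_mask=action_mask)
# reward: KL penalty per token, with the clamped scalar reward added to every unmasked response token.
# kl_estimate: per-token estimate of the divergence between the policy and the reference model.
print(reward.shape, kl_estimate.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```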
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 3d4b8a575..b39ddc4d0 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -7,6 +7,7 @@ from coati.distributed.launch import launch_distributed
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument("-rm", "--reward_model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
@@ -15,7 +16,7 @@ if __name__ == "__main__":
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GPRO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "PPO"])
 
     args = parser.parse_args()
     ray.init(address="local", namespace="ray-example")
@@ -27,26 +28,27 @@ if __name__ == "__main__":
         top_p=0.8,
     )
 
-    if args.backend == "transformers":
+    if args.backend == "transformers":
         inference_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
             )
         )
         train_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
                 use_cache=False,
             )
         )
         generate_config.update(
             dict(
-                max_length=512,
+                max_length=768,
                 do_sample=True,
                 max_new_tokens=None,
                 early_stopping=False,
+                stop_strings=["</answer>"],
             )
         )
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
                 gpu_memory_utilization=0.7,
             )
         )
@@ -57,12 +59,13 @@ if __name__ == "__main__":
         generate_config.update(
             dict(
-                max_tokens=2048,
+                max_tokens=512,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
-                temperature=0.7,
+                temperature=0.5,
                 top_p=0.95,
+                n=1,
             )
         )
     else: