# Source: ColossalAI/applications/ColossalChat/coati/distributed/ppo_consumer.py
# (263 lines, 11 KiB, Python)

from contextlib import nullcontext
from typing import Optional
import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss, ValueLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
from coati.trainer.utils import all_reduce_mean
from coati.models import Critic, disable_dropout
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.optimizer import HybridAdam
@ray.remote
class PPOConsumer(BaseConsumer):
    """Ray actor that runs PPO updates for a policy model and a critic model.

    Each call to :meth:`step` consumes one micro-batch of rollouts, computes a
    verifiable reward, derives GAE advantages from the critic's value estimates
    and back-propagates the policy and critic losses. Optimizer steps are taken
    once every ``self.num_microbatches`` calls.
    """

    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
        num_generations=1,
        gamma: float = 1.0,
        lam: float = 0.95,
        kl_coef: float = 0.05,
        use_wandb=True,
    ):
        """Build the policy, critic and reference models plus their optimizers.

        Args:
            num_producers, num_episodes, rank, world_size, master_addr,
            master_port, num_update_per_episode, num_recv_per_update,
            batch_size, plugin_config, microbatch_size: forwarded unchanged to
                ``BaseConsumer.__init__``.
            model_config: keyword arguments for model construction. Must contain
                a ``"path"`` entry, which is popped (the dict is mutated in
                place) and used to load the models and the tokenizer.
            num_generations: stored on the instance as ``self.num_generations``.
            gamma: discount factor used in the GAE recursion.
            lam: GAE lambda used in the GAE recursion.
            kl_coef: coefficient passed to ``compute_reward_ppo``.
            use_wandb: when true, rank 0 opens a wandb run and logs training
                metrics to it after every optimizer step.
        """
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        self.gamma = gamma
        self.lam = lam
        self.kl_coef = kl_coef

        # NOTE: pop() mutates the dict that was also handed to BaseConsumer above.
        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.critic_model = Critic(path, **model_config)
        self.critic_model.model.gradient_checkpointing_enable()
        self.critic_model.train()

        # Disable dropout
        disable_dropout(self.policy_model)
        disable_dropout(self.critic_model)

        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
        self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)

        # Running sums of per-micro-batch metrics; reset after each optimizer step.
        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_reward = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
        self.accum_advantage = torch.zeros(1, device=self.device)
        self.accum_critic_loss = torch.zeros(1, device=self.device)
        self.accum_count = 0

        # Reference model is loaded from the same checkpoint path as the policy model.
        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations

        # Initialize verifiable reward.
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        self.reward_model = VerifiableReward(
            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
        )

        self.policy_loss_fn = PolicyLoss()
        self.critic_loss_fn = ValueLoss()
        self.global_step = 0

        # Always define the attribute so step() can test it; previously rank 0
        # crashed with AttributeError in step() when use_wandb was False.
        self.wandb_run = None
        if use_wandb and self.rank == 0:
            self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)

    def setup(self):
        """Run the base setup, then wrap models and optimizers with the boosters."""
        super().setup()
        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
        # NOTE(review): `critic_booster` is not created in this class; presumably
        # BaseConsumer.setup() provides it -- confirm.
        self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
            self.critic_model, self.critic_optimizer
        )
        self.reference_model, *_ = self.booster.boost(self.reference_model)

    def _compute_gae(self, reward, value, action_mask):
        """Compute masked GAE advantages without building an autograd graph.

        Reference:
        https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46

        Args:
            reward: per-token rewards, indexed as ``reward[:, t]``.
            value: per-token value estimates, indexed as ``value[:, t]``.
            action_mask: mask whose second dimension gives the number of action
                tokens; the result is multiplied by it.

        Returns:
            A gradient-free tensor of advantages with one column per action token.
        """
        num_action = action_mask.shape[1]
        # The advantages are treated as constants by both losses, so the critic
        # graph is cut up front instead of being built and discarded afterwards.
        with torch.no_grad():
            value = value.detach()
            lastgaelam = 0
            advantage_reversed = []
            for t in reversed(range(num_action)):
                # The value after the final action token is taken to be zero.
                nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
                delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
                lastgaelam = delta + self.gamma * self.lam * lastgaelam
                advantage_reversed.append(lastgaelam)
            return torch.stack(advantage_reversed[::-1], dim=1) * action_mask

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.

        ``kwargs`` must additionally contain ``gt_answer`` and ``response_idx``,
        which are passed to the reward model.

        Returns:
            The accumulated (summed, not averaged) policy loss over the
            micro-batches of this update when an optimizer step is taken,
            otherwise ``None``.
        """
        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
        action_mask = data["action_mask"]
        num_action = action_mask.shape[1]
        old_action_log_probs = data["action_log_probs"].detach()

        # Optimizers only step on the last micro-batch of an update; on the
        # other micro-batches the policy forward/backward runs under no_sync().
        need_update = (step_idx + 1) % self.num_microbatches == 0
        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
        with ctx:
            policy_model_logits = self.policy_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )["logits"]
            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)

            with torch.no_grad():
                reference_model_logits = self.reference_model(
                    input_ids=data["input_ids"],
                    attention_mask=data["attention_mask"],
                )["logits"]
            reference_action_log_probs = calc_action_log_probs(
                reference_model_logits, data["input_ids"], num_action
            )

            value = self.critic_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )
            # Keep the num_action positions that end one token before the
            # sequence end, zeroing the masked ones.
            value = value[:, -num_action - 1 : -1] * action_mask

            r = self.reward_model(
                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
            )
            # The second return value is discarded; the KL that gets logged is
            # computed separately below.
            reward, _ = compute_reward_ppo(
                r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
            )

            # Calculate advantages
            advantage = self._compute_gae(reward, value, action_mask)

            # KL divergence for logging only, so no gradient is tracked for it.
            log_ratio = reference_action_log_probs - action_log_probs.detach()
            per_token_kl = torch.exp(log_ratio) - log_ratio - 1
            # clamp() avoids 0/0 = NaN for a sequence whose mask is all zeros.
            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1).clamp(min=1)

            # Calculate Loss
            loss, skip_update, _ = self.policy_loss_fn(
                action_log_probs,
                old_action_log_probs,
                advantage,
                0,  # kl is already included in the advantage
                action_mask,
            )

            # Critic Loss
            # Hack: use the current value to approximate the old value, should be old value mathematically
            critic_loss = self.critic_loss_fn(
                value,
                value.detach(),
                advantage,
                action_mask=action_mask,
            )

            # NOTE(review): skip_update comes from PolicyLoss; when it is set,
            # neither loss is back-propagated for this micro-batch.
            if not skip_update:
                self.booster.backward(loss, self.optimizer)
                self.critic_booster.backward(critic_loss, self.critic_optimizer)

            loss = all_reduce_mean(loss, self.plugin)
            critic_loss = all_reduce_mean(critic_loss, self.plugin)
            r_mean = all_reduce_mean(r.mean(), self.plugin)
            kl_mean = all_reduce_mean(kl.mean(), self.plugin)
            advantage_mean = all_reduce_mean(advantage.mean(), self.plugin)

            self.accum_loss.add_(loss.data)
            self.accum_critic_loss.add_(critic_loss.data)
            self.accum_advantage.add_(advantage_mean.data)
            self.accum_reward.add_(r_mean.data)
            self.accum_kl.add_(kl_mean.data)
            self.accum_count += 1

            if self.rank == 0:
                print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}")

        if not need_update:
            return None

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.critic_optimizer.step()
        self.critic_optimizer.zero_grad()
        loss_scalar = self.accum_loss.item()

        if self.rank == 0:
            count = self.accum_count
            mean_loss = self.accum_loss.item() / count
            mean_reward = self.accum_reward.item() / count
            mean_kl = self.accum_kl.item() / count
            print(
                "Loss:",
                mean_loss,
                "Reward:",
                mean_reward,
                "KL:",
                mean_kl,
            )
            # Every third update, print up to three decoded samples with their rewards.
            if self.global_step % 3 == 0:
                for i in range(min(3, data["input_ids"].shape[0])):
                    response_decoded_for_logging = self.tokenizer.decode(
                        data["input_ids"][i], skip_special_tokens=True
                    )
                    response_reward_for_logging = r[i]
                    print(
                        f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
                    )
            # wandb_run is None when use_wandb was False.
            if self.wandb_run is not None:
                self.wandb_run.log(
                    {
                        "train/loss": mean_loss,
                        "train/reward": mean_reward,
                        "train/kl": mean_kl,
                        "train/critic_loss": self.accum_critic_loss.item() / count,
                        "train/advantage": self.accum_advantage.item() / count,
                    }
                )

        self.accum_loss.zero_()
        self.accum_reward.zero_()
        self.accum_kl.zero_()
        self.accum_advantage.zero_()
        self.accum_critic_loss.zero_()
        self.accum_count = 0
        self.global_step += 1
        return loss_scalar

    def state_dict(self):
        """Return the state dict of the unwrapped policy model."""
        # NOTE(review): presumably waits for pending parameter all-gathers
        # before the weights are read -- confirm against the booster plugin.
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict