# mirror of https://github.com/hpcaitech/ColossalAI
from contextlib import nullcontext
from typing import Optional

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss, ValueLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
from coati.trainer.utils import all_reduce_mean
from coati.models import Critic, disable_dropout
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.optimizer import HybridAdam


@ray.remote
class PPOConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
        num_generations=1,
        gamma: float = 1.0,
        lam: float = 0.95,
        kl_coef: float = 0.05,
        use_wandb=True,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        self.gamma = gamma
        self.lam = lam
        self.kl_coef = kl_coef

        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.critic_model = Critic(path, **model_config)
        self.critic_model.model.gradient_checkpointing_enable()
        self.critic_model.train()

        # Disable dropout
        disable_dropout(self.policy_model)
        disable_dropout(self.critic_model)

        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
        self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)

        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_reward = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
        self.accum_advantage = torch.zeros(1, device=self.device)
        self.accum_critic_loss = torch.zeros(1, device=self.device)
        self.accum_count = 0

        # Reference model is initialized from the policy model.
        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations

        # Initialize verifiable reward.
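        # The tags below describe the expected <think>...</think><answer>...</answer> response
        # format (each tag exactly once); math_reward_fn scores the answer against gt_answer.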
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        self.reward_model = VerifiableReward(
            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
        )

        self.policy_loss_fn = PolicyLoss()
        self.critic_loss_fn = ValueLoss()
        self.global_step = 0
        if use_wandb and self.rank == 0:
            self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)

    def setup(self):
        super().setup()
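        # booster.boost wraps each model (and its optimizer) with the distributed plugin
        # configured in plugin_config, so backward and stepping below go through the booster.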
        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
        self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
            self.critic_model, self.critic_optimizer
        )
        self.reference_model, *_ = self.booster.boost(self.reference_model)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
        """

        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
        action_mask = data["action_mask"]
        num_action = action_mask.shape[1]
        old_action_log_probs = data["action_log_probs"].detach()

        need_update = (step_idx + 1) % self.num_microbatches == 0
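
        # Gradient accumulation: skip gradient synchronization (booster.no_sync) on every
        # microbatch except the last one in the accumulation window; the optimizers only
        # step when need_update is True.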
        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
        with ctx:
            policy_model_logits = self.policy_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )["logits"]
            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)

            with torch.no_grad():
                reference_model_logits = self.reference_model(
                    input_ids=data["input_ids"],
                    attention_mask=data["attention_mask"],
                )["logits"]
            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)

            value = self.critic_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )
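            # The critic emits one value per position; keep only the positions that predict each
            # response token (shifted left by one) and zero out padding with action_mask.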
            value = value[:, -num_action - 1 : -1] * action_mask

            r = self.reward_model(
                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
            )
            reward, kl = compute_reward_ppo(
                r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
            )
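            # compute_reward_ppo is expected to turn the scalar verifiable reward r into a
            # per-token reward, shaped by a KL penalty against the reference model weighted
            # by self.kl_coef.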

            # Calculate advantages
            # Reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46
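            # GAE recursion over response tokens:
            #   delta_t = reward_t + gamma * V_{t+1} - V_t
            #   A_t     = delta_t + gamma * lam * A_{t+1}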
            lastgaelam = 0
            advantage_reversed = []
            for t in reversed(range(num_action)):
                nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
                delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
                lastgaelam = delta + self.gamma * self.lam * lastgaelam
                advantage_reversed.append(lastgaelam)
            advantage = torch.stack(advantage_reversed[::-1], dim=1) * action_mask
            advantage = advantage.detach()

            # KL divergence for logging
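            # Uses the low-variance k3 estimator of KL(policy || reference):
            #   exp(ref - pol) - (ref - pol) - 1 >= 0 for every token.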
            per_token_kl = (
                torch.exp(reference_action_log_probs - action_log_probs)
                - (reference_action_log_probs - action_log_probs)
                - 1
            )
            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

            # Calculate Loss
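            # PolicyLoss is assumed to apply the clipped PPO surrogate on the probability ratio
            # exp(action_log_probs - old_action_log_probs); 0 is passed for the explicit KL term
            # because the KL penalty is already folded into the reward above.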
            loss, skip_update, _ = self.policy_loss_fn(
                action_log_probs,
                old_action_log_probs,
                advantage,
                0,  # kl is already included in the advantage
                action_mask,
            )

            # Critic Loss
            # Hack: use the current value to approximate the old value; mathematically this
            # should be the value recorded at rollout time.
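            # If ValueLoss performs the usual PPO value clipping against the old values, that
            # clipping is effectively a no-op here because old and current values coincide.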
            critic_loss = self.critic_loss_fn(
                value,
                value.detach(),
                advantage,
                action_mask=action_mask,
            )

            if not skip_update:
                self.booster.backward(loss, self.optimizer)
                self.critic_booster.backward(critic_loss, self.critic_optimizer)
            loss = all_reduce_mean(loss, self.plugin)
            critic_loss = all_reduce_mean(critic_loss, self.plugin)
            r_mean = all_reduce_mean(r.mean(), self.plugin)
            kl = all_reduce_mean(kl.mean(), self.plugin)
            advantage = all_reduce_mean(advantage.mean(), self.plugin)
            self.accum_loss.add_(loss.data)
            self.accum_critic_loss.add_(critic_loss.data)
            self.accum_advantage.add_(advantage.data)
            self.accum_reward.add_(r_mean.data)
            self.accum_kl.add_(kl.data)
            self.accum_count += 1
            if self.rank == 0:
                print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}")

        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.critic_optimizer.step()
            self.critic_optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            if self.rank == 0:
                print(
                    "Loss:",
                    self.accum_loss.item() / self.accum_count,
                    "Reward:",
                    self.accum_reward.item() / self.accum_count,
                    "KL:",
                    self.accum_kl.item() / self.accum_count,
                )
                if self.global_step % 3 == 0:
                    for i in range(min(3, data["input_ids"].shape[0])):
                        response_decoded_for_logging = self.tokenizer.decode(
                            data["input_ids"][i], skip_special_tokens=True
                        )
                        response_reward_for_logging = r[i]
                        print(
                            f"###### Generation Sample {i} ######\n"
                            f"Response: {response_decoded_for_logging}\n"
                            f"Reward: {response_reward_for_logging}"
                        )
                self.wandb_run.log(
                    {
                        "train/loss": self.accum_loss.item() / self.accum_count,
                        "train/reward": self.accum_reward.item() / self.accum_count,
                        "train/kl": self.accum_kl.item() / self.accum_count,
                        "train/critic_loss": self.accum_critic_loss.item() / self.accum_count,
                        "train/advantage": self.accum_advantage.item() / self.accum_count,
                    }
                )
            self.accum_loss.zero_()
            self.accum_reward.zero_()
            self.accum_kl.zero_()
            self.accum_advantage.zero_()
            self.accum_critic_loss.zero_()
            self.accum_count = 0
            self.global_step += 1
            return loss_scalar

    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict