feat/ppo
YeAnbang 2025-03-07 18:29:34 +08:00
parent eb6337f07f
commit 6a6634b6e8
10 changed files with 350 additions and 12 deletions

.gitignore vendored
View File

@@ -162,4 +162,5 @@ coverage.xml
# log, test files - ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/wandb
applications/ColossalChat/tests/logs

View File

@@ -14,6 +14,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
@@ -76,6 +77,10 @@ class BaseConsumer:
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
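# If this consumer also owns a critic, build a separate plugin/booster for it;
# the shardformer policy is resolved from the wrapped backbone (critic_model.model).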
if hasattr(self, "critic_model"):
plugin_config.update({"custom_policy": get_autopolicy(self.critic_model.model)})
self.critic_plugin = HybridParallelPlugin(**plugin_config)
self.critic_booster = Booster(plugin=self.critic_plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)

View File

@@ -60,6 +60,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.generate_config["tokenizer"] = tokenizer
self.tokenizer = tokenizer
@torch.no_grad()
@@ -76,21 +77,26 @@ class TransformersInferenceBackend(BaseInferenceBackend):
action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
action_log_probs = torch.cat(action_log_probs, dim=1)
# get action mask
response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
if self.tokenizer.eos_token_id is not None:
for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
action_mask[indices[0], indices[1] + 1 :] = 0
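# response_idx stores the [start, end] token positions (end inclusive) of each response
# within the padded sequence, for use by the reward functions downstream.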
response_idx[:, 0] = input_len
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
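# When the backend returns several generations per prompt, tile the prompt attention mask
# so it lines up with the generated sequences before concatenating.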
if attention_mask.size(0) != action_mask.size(0):
assert action_mask.size(0) % attention_mask.size(0) == 0
attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
attention_mask = torch.cat((attention_mask, action_mask), dim=1)
data = {
"input_ids": out.sequences,
"attention_mask": attention_mask,
"action_log_probs": action_log_probs,
"action_mask": action_mask,
"response_idx": response_idx
}
return data
@@ -154,7 +160,6 @@ class VLLMInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
n=4,
)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
@@ -167,7 +172,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config.update(self.FORCE_GENERATE_CONFIG)
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
self.num_generations = generate_config["n"]
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:

View File

@@ -4,11 +4,13 @@ import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .ppo_consumer import PPOConsumer
from .producer import SimpleProducer
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
"PPO": PPOConsumer
}

View File

@@ -45,3 +45,31 @@ class PolicyLoss(nn.Module):
loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()
class ValueLoss(nn.Module):
"""
Clipped value (critic) loss for PPO.
"""
def __init__(self, clip_eps: float = 0.2) -> None:
super().__init__()
self.clip_eps = clip_eps
def forward(
self,
values: torch.Tensor,
old_values: torch.Tensor,
advantage: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
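# Clipped value loss: the return target is reconstructed as advantage + old value, the new
# estimate may only move within +/- clip_eps of the old one, and the larger (pessimistic)
# of the clipped/unclipped squared errors is averaged over valid action tokens.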
returns = advantage + old_values
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
if action_mask is not None:
# loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
else:
loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss
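A minimal sketch of how ValueLoss is exercised (the shapes and values below are illustrative assumptions, not part of this change):
import torch
from coati.distributed.loss import ValueLoss

value_loss_fn = ValueLoss(clip_eps=0.2)
values = torch.randn(2, 8)                                # fresh critic estimates, [batch, response_length]
old_values = values + 0.05 * torch.randn(2, 8)            # estimates recorded at rollout time
advantage = torch.randn(2, 8)
action_mask = torch.ones(2, 8)                            # 1 for real response tokens, 0 for padding
loss = value_loss_fn(values, old_values, advantage, action_mask=action_mask)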

View File

@@ -0,0 +1,262 @@
from contextlib import nullcontext
from typing import Optional
import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss, ValueLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
from coati.trainer.utils import all_reduce_mean
from coati.models import Critic, disable_dropout
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.optimizer import HybridAdam
@ray.remote
class PPOConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
num_generations=1,
gamma: float = 1.0,
lam: float = 0.95,
kl_coef: float = 0.05,
use_wandb=True,
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
self.gamma = gamma
self.lam = lam
self.kl_coef = kl_coef
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.critic_model = Critic(path, **model_config)
self.critic_model.model.gradient_checkpointing_enable()
self.critic_model.train()
# Disable dropout
disable_dropout(self.policy_model)
disable_dropout(self.critic_model)
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_advantage = torch.zeros(1, device=self.device)
self.accum_critic_loss = torch.zeros(1, device=self.device)
self.accum_count = 0
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
)
self.policy_loss_fn = PolicyLoss()
self.critic_loss_fn = ValueLoss()
self.global_step = 0
if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)
def setup(self):
super().setup()
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
self.critic_model, self.critic_optimizer
)
self.reference_model, *_ = self.booster.boost(self.reference_model)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""
Step data from policy model:
[{
"input_ids": torch.Tensor,
"attention_mask": torch.Tensor,
"action_mask": torch.Tensor,
"action_log_probs": torch.Tensor,
},
...]
Format:
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"].detach()
need_update = (step_idx + 1) % self.num_microbatches == 0
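# Only synchronize policy gradients on the micro-batch that triggers the optimizer update;
# earlier micro-batches accumulate gradients locally under no_sync.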
ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
value = self.critic_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)
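# Keep only the value estimates aligned with response tokens: the value at position t - 1
# scores the state right before response token t is generated; padded positions are masked out.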
value = value[:, -num_action - 1 : -1] * action_mask
r = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward, kl = compute_reward_ppo(
r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
)
# Calculate advantages
# reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46
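# GAE recursion, computed backwards over the response:
#   delta_t = r_t + gamma * V_{t+1} - V_t
#   A_t     = delta_t + gamma * lam * A_{t+1}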
lastgaelam = 0
advantage_reversed = []
for t in reversed(range(num_action)):
nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantage_reversed.append(lastgaelam)
advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
advantage = advantage.detach()
# KL divergence for logging
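# Uses the non-negative estimator exp(x) - x - 1 with x = log p_ref - log p_policy,
# averaged over valid response tokens.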
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
# Calculate Loss
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantage,
0, # kl is already included in the advantage
action_mask,
)
# Critic Loss
# Hack: the detached current value stands in for the old value; strictly, PPO should clip against the value recorded at rollout time.
critic_loss = self.critic_loss_fn(
value,
value.detach(),
advantage,
action_mask=action_mask,
)
if not skip_update:
self.booster.backward(loss, self.optimizer)
self.critic_booster.backward(critic_loss, self.critic_optimizer)
loss = all_reduce_mean(loss, self.plugin)
critic_loss = all_reduce_mean(critic_loss, self.plugin)
r_mean = all_reduce_mean(r.mean(), self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
advantage = all_reduce_mean(advantage.mean(), self.plugin)
self.accum_loss.add_(loss.data)
self.accum_critic_loss.add_(critic_loss.data)
self.accum_advantage.add_(advantage.data)
self.accum_reward.add_(r_mean.data)
self.accum_kl.add_(kl.data)
self.accum_count += 1
if self.rank == 0:
print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}")
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
self.critic_optimizer.step()
self.critic_optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
if self.rank == 0:
print(
"Loss:",
self.accum_loss.item() / self.accum_count,
"Reward:",
self.accum_reward.item() / self.accum_count,
"KL:",
self.accum_kl.item() / self.accum_count,
)
if self.global_step % 3 == 0:
for i in range(min(3, data["input_ids"].shape[0])):
response_decoded_for_logging = self.tokenizer.decode(
data["input_ids"][i], skip_special_tokens=True
)
response_reward_for_logging = r[i]
print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
self.wandb_run.log(
{
"train/loss": self.accum_loss.item() / self.accum_count,
"train/reward": self.accum_reward.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/critic_loss": self.accum_critic_loss.item() / self.accum_count,
"train/advantage": self.accum_advantage.item() / self.accum_count,
}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_kl.zero_()
self.accum_advantage.zero_()
self.accum_critic_loss.zero_()
self.accum_count = 0
self.global_step += 1
return loss_scalar
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict

View File

@@ -154,6 +154,11 @@ class SimpleProducer(BaseProducer):
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
if self.backend_cls.__name__ == "TransformersInferenceBackend":
gt_answer = kwargs.pop("gt_answer")
out = self.model.generate(input_ids, attention_mask, **kwargs)
out["gt_answer"] = gt_answer.to(out["input_ids"].device)
return out
return self.model.generate(input_ids, attention_mask, **kwargs)
def load_state_dict(self, state_dict):

View File

@@ -19,8 +19,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 2.0
# if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
# reward = reward + 2.0
return reward

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -99,3 +99,30 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
def compute_reward_ppo(
r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
reward_eps=5,
) -> torch.Tensor:
"""
Args:
log_probs: [batch_size, response_length]
log_probs_base: [batch_size, response_length]
action_mask: [batch_size, response_length]
r: float
Returns:
reward: [batch_size, response_length]
"""
log_ratio = log_probs - log_probs_base # address numerical instability issue
kl = -kl_coef * log_ratio * action_mask
reward = kl
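# Each response token starts from its per-token KL penalty; the clipped sequence-level
# reward is then added to every response position, and positions past the response are zeroed.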
r_clip = torch.clamp(r, -reward_eps, reward_eps)
for i in range(action_mask.size(0)):
assert action_mask[i].sum() > 0
reward[i, : action_mask[i].sum()] += r_clip[i]
reward[i, action_mask[i].sum() :] *= 0
return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
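For intuition, a tiny worked example of the shaping (shapes and values are illustrative assumptions, not part of this change):
import torch
from coati.distributed.utils import compute_reward_ppo

log_probs = torch.zeros(1, 4)              # policy log-probs, [batch, response_length]
log_probs_base = torch.zeros(1, 4)         # reference log-probs (identical here, so zero KL penalty)
action_mask = torch.tensor([[1, 1, 1, 0]])
r = torch.tensor([2.0])                    # sequence-level verifiable reward
reward, approx_kl = compute_reward_ppo(r, 0.05, log_probs, log_probs_base, action_mask)
# reward == [[2., 2., 2., 0.]] and approx_kl is all zeros in this case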

View File

@@ -7,6 +7,7 @@ from coati.distributed.launch import launch_distributed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-rm", "--reward_model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
@@ -15,7 +16,7 @@ if __name__ == "__main__":
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GPRO"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GPRO", "PPO"])
args = parser.parse_args()
ray.init(address="local", namespace="ray-example")
@@ -27,26 +28,27 @@ if __name__ == "__main__":
top_p=0.8,
)
if args.backend == "transformers":
if args.backend == "transformers":
inference_model_config.update(
dict(
attn_implementation="flash_attention_2",
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
attn_implementation="flash_attention_2",
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
use_cache=False,
)
)
generate_config.update(
dict(
max_length=512,
max_length=768,
do_sample=True,
max_new_tokens=None,
early_stopping=False,
stop_strings=["</answer>"]
)
)
elif args.backend == "vllm":
@@ -57,12 +59,13 @@
)
generate_config.update(
dict(
max_tokens=2048,
max_tokens=512,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
temperature=0.7,
temperature=0.5,
top_p=0.95,
n=1,
)
)
else: