[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
feat/ppo
pre-commit-ci[bot] 2025-03-07 10:43:01 +00:00
parent c8e13a9403
commit 6e096362ef
8 changed files with 27 additions and 36 deletions

View File

@ -13,8 +13,8 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch

View File

@ -90,13 +90,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
attention_mask = torch.cat((attention_mask, action_mask), dim=1)
data = {
"input_ids": out.sequences,
"attention_mask": attention_mask,
"action_log_probs": action_log_probs,
"action_mask": action_mask,
"response_idx": response_idx
"response_idx": response_idx,
}
return data

View File

@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
from .ppo_consumer import PPOConsumer
from .producer import SimpleProducer
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
"PPO": PPOConsumer
}
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}
def get_jsonl_size_fast(path: str) -> int:

View File

@ -72,4 +72,4 @@ class ValueLoss(nn.Module):
loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
else:
loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss
return 0.5 * loss

View File

@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
from coati.trainer.utils import all_reduce_mean
from coati.models import Critic, disable_dropout
from coati.trainer.utils import all_reduce_mean
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.optimizer import HybridAdam
@ -33,9 +33,9 @@ class PPOConsumer(BaseConsumer):
plugin_config,
microbatch_size=1,
num_generations=1,
gamma:float=1.0,
lam:float=0.95,
kl_coef:float=0.05,
gamma: float = 1.0,
lam: float = 0.95,
kl_coef: float = 0.05,
use_wandb=True,
):
super().__init__(
@ -55,7 +55,7 @@ class PPOConsumer(BaseConsumer):
self.gamma = gamma
self.lam = lam
self.kl_coef = kl_coef
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
@ -63,14 +63,14 @@ class PPOConsumer(BaseConsumer):
self.critic_model = Critic(path, **model_config)
self.critic_model.model.gradient_checkpointing_enable()
self.critic_model.train()
# Disable dropout
disable_dropout(self.policy_model)
disable_dropout(self.critic_model)
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
@ -152,15 +152,13 @@ class PPOConsumer(BaseConsumer):
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)
value = value[:, -num_action -1: -1] * action_mask
value = value[:, -num_action - 1 : -1] * action_mask
r = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
reward, kl = compute_reward_ppo(
r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
)
# Calculate advantages
# reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46lastgaelam = 0
lastgaelam = 0
@ -172,7 +170,7 @@ class PPOConsumer(BaseConsumer):
advantage_reversed.append(lastgaelam)
advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
advantage = advantage.detach()
# KL divergence for logging
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
@ -189,7 +187,7 @@ class PPOConsumer(BaseConsumer):
0, # kl is already included in the advantage
action_mask,
)
# Critic Loss
# Hack: use the current value to approximate the old value, should be old value mathematically
critic_loss = self.critic_loss_fn(
@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
data["input_ids"][i], skip_special_tokens=True
)
response_reward_for_logging = r[i]
print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
print(
f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
)
self.wandb_run.log(
{
"train/loss": self.accum_loss.item() / self.accum_count,

View File

@ -94,11 +94,11 @@ class BaseProducer:
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
)
for episode in range(self.num_episodes):
for episode in range(self.num_episodes):
self.dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.dataloader):
if i >= num_valid_microbatches:
break
break
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)

View File

@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
mean = tensor / (mask_sum + 1e-8)
return mean
def compute_reward_ppo(
r: Union[torch.Tensor, float],
kl_coef: float,
@ -125,4 +126,4 @@ def compute_reward_ppo(
assert action_mask[i].sum() > 0
reward[i, : action_mask[i].sum()] += r_clip[i]
reward[i, action_mask[i].sum() :] *= 0
return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask

View File

@ -28,7 +28,7 @@ if __name__ == "__main__":
top_p=0.8,
)
if args.backend == "transformers":
if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
@ -43,13 +43,7 @@ if __name__ == "__main__":
)
)
generate_config.update(
dict(
max_length=768,
do_sample=True,
max_new_tokens=None,
early_stopping=False,
stop_strings=["</answer>"]
)
dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
)
elif args.backend == "vllm":
inference_model_config.update(