[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
feat/ppo
pre-commit-ci[bot] 2025-03-07 10:43:01 +00:00
parent c8e13a9403
commit 6e096362ef
8 changed files with 27 additions and 36 deletions


@@ -13,8 +13,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch


@@ -96,7 +96,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
-            "response_idx": response_idx
+            "response_idx": response_idx,
         }
         return data
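For context on the rollout fields above: a minimal sketch of how action_log_probs and action_mask are typically derived from the generated sequences. The helper name, its arguments, and the right-padding convention are assumptions for illustration, not this backend's actual code.

import torch

def rollout_fields(sequences, logits, prompt_len, pad_token_id):
    # Hypothetical helper (assumption): per-token log-probs and mask over the
    # generated (response) part of `sequences`, which holds prompt + response ids.
    num_actions = sequences.size(1) - prompt_len
    # logits[:, t] predicts token t + 1, so shift by one before gathering.
    log_probs = torch.log_softmax(logits[:, :-1].float(), dim=-1)
    token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    action_log_probs = token_log_probs[:, -num_actions:]
    action_mask = (sequences[:, -num_actions:] != pad_token_id).long()
    return action_log_probs, action_mask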


@@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
 from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-    "PPO": PPOConsumer
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}


 def get_jsonl_size_fast(path: str) -> int:
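The collapsed ALGO_MAP literal is purely a formatting change; the dict is still the registry the launcher uses to pick a consumer class. A minimal usage sketch, assuming ALGO_MAP as defined above (the algo variable and the error message are illustrative, not from this file):

algo = "PPO"  # e.g. parsed from a CLI argument
try:
    consumer_cls = ALGO_MAP[algo]  # -> PPOConsumer
except KeyError:
    raise ValueError(f"Unknown algorithm {algo!r}; expected one of {list(ALGO_MAP)}")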


@@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
-from coati.trainer.utils import all_reduce_mean
 from coati.models import Critic, disable_dropout
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.optimizer import HybridAdam
@@ -33,9 +33,9 @@ class PPOConsumer(BaseConsumer):
         plugin_config,
         microbatch_size=1,
         num_generations=1,
-        gamma:float=1.0,
-        lam:float=0.95,
-        kl_coef:float=0.05,
+        gamma: float = 1.0,
+        lam: float = 0.95,
+        kl_coef: float = 0.05,
         use_wandb=True,
     ):
         super().__init__(
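The reannotated defaults above are the usual PPO hyperparameters: gamma and lam drive generalized advantage estimation (GAE), while kl_coef scales the penalty against the reference policy. As a reference for what gamma and lam control, here is a generic GAE sketch over right-padded response tokens; this is a textbook formulation under those assumptions, not PPOConsumer's actual code.

import torch

def gae_advantages(rewards, values, action_mask, gamma=1.0, lam=0.95):
    # Generic GAE: delta_t = r_t + gamma * V_{t+1} - V_t and
    # A_t = delta_t + gamma * lam * A_{t+1}, walking backwards over the
    # response tokens; padded positions are treated as terminal and zeroed.
    advantages = torch.zeros_like(values)
    last_adv = torch.zeros(values.size(0), device=values.device)
    next_value = torch.zeros(values.size(0), device=values.device)
    for t in reversed(range(values.size(1))):
        mask_t = action_mask[:, t]
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_adv = (delta + gamma * lam * last_adv) * mask_t
        advantages[:, t] = last_adv
        next_value = values[:, t] * mask_t
    returns = (advantages + values) * action_mask
    return advantages, returns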
@@ -152,11 +152,9 @@ class PPOConsumer(BaseConsumer):
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )
-            value = value[:, -num_action -1: -1] * action_mask
+            value = value[:, -num_action - 1 : -1] * action_mask

-            r = self.reward_model(
-                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-            )
+            r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
             reward, kl = compute_reward_ppo(
                 r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
             )
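On the block above: the value slice appears to read each critic output from the position just before the response token it scores, and compute_reward_ppo folds the rule-based score r together with a per-token KL penalty. The body below is an assumption that mirrors the call signature and the common PPO-RLHF shaping (score credited to the last non-padded token), not the repository's exact implementation.

import torch

def compute_reward_ppo_sketch(r, kl_coef, log_probs, ref_log_probs, action_mask):
    # Assumed shaping: token-level KL estimate between the trained policy and the
    # frozen reference, scaled by kl_coef and negated, with the sequence-level
    # score r added at the last valid response token.
    kl = (log_probs - ref_log_probs) * action_mask        # [B, num_actions]
    reward = -kl_coef * kl
    last_idx = action_mask.cumsum(dim=1).argmax(dim=1)    # last non-padded position
    reward[torch.arange(reward.size(0)), last_idx] += r
    return reward, kl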
@@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
                     data["input_ids"][i], skip_special_tokens=True
                 )
                 response_reward_for_logging = r[i]
-                print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                print(
+                    f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
+                )
             self.wandb_run.log(
                 {
                     "train/loss": self.accum_loss.item() / self.accum_count,


@@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mean = tensor / (mask_sum + 1e-8)
     return mean

+
 def compute_reward_ppo(
     r: Union[torch.Tensor, float],
     kl_coef: float,
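The context lines above show only the tail of masked_mean; for reference, a typical full definition consistent with those lines (a sketch, not necessarily byte-identical to the file):

import torch

def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Mean over `dim`, counting only positions where mask == 1; the 1e-8 guards
    # against an all-zero mask.
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean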


@@ -43,13 +43,7 @@ if __name__ == "__main__":
             )
         )
         generate_config.update(
-            dict(
-                max_length=768,
-                do_sample=True,
-                max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"]
-            )
+            dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
         )
     elif args.backend == "vllm":
         inference_model_config.update(
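One note on the collapsed generate_config: in transformers, stop_strings is matched against decoded text, so generate() also needs the tokenizer to locate the stop markers. A hedged usage sketch of how such a config is typically consumed; the model name and prompt are placeholders, not taken from this PR.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model, not from this PR
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

generate_config = dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
inputs = tokenizer("What is 1 + 1? Answer inside <answer></answer>.", return_tensors="pt")
# Passing the tokenizer lets generate() detect the stop string in the sampled text.
out = model.generate(**inputs, generation_config=GenerationConfig(**generate_config), tokenizer=tokenizer)
print(tokenizer.decode(out[0], skip_special_tokens=True))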