[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
feat/ppo
pre-commit-ci[bot] 2025-03-07 10:43:01 +00:00
parent c8e13a9403
commit 6e096362ef
8 changed files with 27 additions and 36 deletions

@@ -13,8 +13,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch

@@ -96,7 +96,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
-            "response_idx": response_idx
+            "response_idx": response_idx,
         }
         return data

@@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
 from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-    "PPO": PPOConsumer
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
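
A minimal, self-contained sketch of the dispatch pattern ALGO_MAP serves: mapping an algorithm name (for example from a CLI flag) to a consumer class and instantiating the chosen one. The placeholder classes and the make_consumer helper below are illustrative, not code from this repository.

class SimpleConsumer: ...
class GRPOConsumer: ...
class PPOConsumer: ...

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}

def make_consumer(algo: str, **kwargs):
    # Look up the consumer class by name and instantiate it; fail early on typos.
    if algo not in ALGO_MAP:
        raise ValueError(f"Unknown algorithm {algo!r}; expected one of {list(ALGO_MAP)}")
    return ALGO_MAP[algo](**kwargs)

consumer = make_consumer("PPO")  # -> PPOConsumer instance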

@@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
-from coati.trainer.utils import all_reduce_mean
 from coati.models import Critic, disable_dropout
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -154,9 +154,7 @@ class PPOConsumer(BaseConsumer):
                 )
                 value = value[:, -num_action - 1 : -1] * action_mask
 
-                r = self.reward_model(
-                    data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-                )
+                r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
                 reward, kl = compute_reward_ppo(
                     r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
                 )
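
For context, a self-contained sketch of the conventional KL-penalized reward that a helper like compute_reward_ppo computes in PPO-style RLHF: the scalar reward-model score is offset by kl_coef times an approximate KL between the policy and the frozen reference, estimated from the action log-probabilities over the generated tokens. This is an illustration under those assumptions, not the actual implementation.

import torch

def kl_penalized_reward(
    r: torch.Tensor,                            # (batch,) scalar reward per sequence
    kl_coef: float,
    action_log_probs: torch.Tensor,             # (batch, num_actions) from the current policy
    reference_action_log_probs: torch.Tensor,   # (batch, num_actions) from the reference model
    action_mask: torch.Tensor,                  # (batch, num_actions), 1 where tokens were generated
):
    # Approximate per-sequence KL, averaged over generated tokens only.
    log_ratio = (action_log_probs - reference_action_log_probs) * action_mask
    kl = log_ratio.sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)
    # Penalize divergence from the reference policy.
    reward = r - kl_coef * kl
    return reward, kl

# Example shapes: batch of 2 sequences, 4 generated tokens each.
r = torch.tensor([1.0, 0.0])
logp = torch.randn(2, 4)
ref_logp = torch.randn(2, 4)
mask = torch.ones(2, 4)
reward, kl = kl_penalized_reward(r, 0.05, logp, ref_logp, mask)
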
@@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
                         data["input_ids"][i], skip_special_tokens=True
                     )
                     response_reward_for_logging = r[i]
-                    print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                    print(
+                        f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
+                    )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,

@@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mean = tensor / (mask_sum + 1e-8)
     return mean
 
+
 def compute_reward_ppo(
     r: Union[torch.Tensor, float],
     kl_coef: float,

@@ -43,13 +43,7 @@ if __name__ == "__main__":
             )
         )
         generate_config.update(
-            dict(
-                max_length=768,
-                do_sample=True,
-                max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"]
-            )
+            dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
         )
     elif args.backend == "vllm":
         inference_model_config.update(
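
For reference, a sketch of how a generate_config like the one above is typically forwarded to Hugging Face generate() by a transformers backend; in recent transformers versions, stop_strings requires the tokenizer to be passed to generate(). The checkpoint name and prompt below are placeholders, not values from this repository.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

generate_config = dict(
    max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"]
)

inputs = tokenizer("Compute 2 + 2 and wrap the result in <answer></answer> tags.", return_tensors="pt")
# stop_strings is resolved against the tokenizer, so it must be passed explicitly.
output_ids = model.generate(**inputs, **generate_config, tokenizer=tokenizer)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))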