mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.cifeat/ppo
parent
c8e13a9403
commit
6e096362ef
|
@ -13,8 +13,8 @@ from colossalai.booster import Booster
|
|||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .comm import ray_broadcast_tensor_dict
|
||||
from .utils import bind_batch, post_recv, unbind_batch
|
||||
|
|
|
@ -96,7 +96,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
|||
"attention_mask": attention_mask,
|
||||
"action_log_probs": action_log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
return data
|
||||
|
||||
|
|
|
@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
|
|||
from .ppo_consumer import PPOConsumer
|
||||
from .producer import SimpleProducer
|
||||
|
||||
ALGO_MAP = {
|
||||
"Simple": SimpleConsumer,
|
||||
"GRPO": GRPOConsumer,
|
||||
"PPO": PPOConsumer
|
||||
}
|
||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
|
|
|
@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
|
|||
from coati.distributed.reward.reward_fn import math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from coati.models import Critic, disable_dropout
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
@ -33,9 +33,9 @@ class PPOConsumer(BaseConsumer):
|
|||
plugin_config,
|
||||
microbatch_size=1,
|
||||
num_generations=1,
|
||||
gamma:float=1.0,
|
||||
lam:float=0.95,
|
||||
kl_coef:float=0.05,
|
||||
gamma: float = 1.0,
|
||||
lam: float = 0.95,
|
||||
kl_coef: float = 0.05,
|
||||
use_wandb=True,
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -152,11 +152,9 @@ class PPOConsumer(BaseConsumer):
|
|||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)
|
||||
value = value[:, -num_action -1: -1] * action_mask
|
||||
value = value[:, -num_action - 1 : -1] * action_mask
|
||||
|
||||
r = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
|
||||
reward, kl = compute_reward_ppo(
|
||||
r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
|
||||
)
|
||||
|
@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
|
|||
data["input_ids"][i], skip_special_tokens=True
|
||||
)
|
||||
response_reward_for_logging = r[i]
|
||||
print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
|
||||
print(
|
||||
f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
|
||||
)
|
||||
self.wandb_run.log(
|
||||
{
|
||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||
|
|
|
@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
|
|||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
||||
|
||||
def compute_reward_ppo(
|
||||
r: Union[torch.Tensor, float],
|
||||
kl_coef: float,
|
||||
|
|
|
@ -43,13 +43,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_length=768,
|
||||
do_sample=True,
|
||||
max_new_tokens=None,
|
||||
early_stopping=False,
|
||||
stop_strings=["</answer>"]
|
||||
)
|
||||
dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(
|
||||
|
|
Loading…
Reference in New Issue