mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

branch feat/ppo
parent c8e13a9403
commit 6e096362ef
@@ -13,8 +13,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch
@@ -96,7 +96,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
-            "response_idx": response_idx
+            "response_idx": response_idx,
         }
         return data
 
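The `action_log_probs` and `action_mask` entries in the payload above are per-token quantities over the generated response. Below is a minimal sketch of how such log-probs are commonly gathered from causal-LM logits; the helper name, tensor shapes, and the random-tensor check are illustrative assumptions, not the backend's exact code.

```python
import torch
import torch.nn.functional as F


def gather_action_log_probs(logits: torch.Tensor, sequences: torch.Tensor, num_actions: int) -> torch.Tensor:
    """Log-probs of the last `num_actions` generated tokens (sketch, not the backend's code).

    logits:    (B, T, V) causal-LM logits over the full prompt + response
    sequences: (B, T)    the corresponding token ids
    """
    # Logits at position t predict token t+1, so shift by one before gathering.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)                          # (B, T-1, V)
    per_token = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)  # (B, T-1)
    return per_token[:, -num_actions:]                                            # response tokens only


# Shape check with random tensors.
B, T, V, num_actions = 2, 6, 10, 3
print(gather_action_log_probs(torch.randn(B, T, V), torch.randint(0, V, (B, T)), num_actions).shape)
# torch.Size([2, 3])
```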
@@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
 from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-    "PPO": PPOConsumer
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
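`ALGO_MAP` is a plain name-to-class registry. A minimal sketch of how such a registry is typically resolved from a CLI flag follows; the stub classes, the `--algo` argument name, and the error handling are assumptions for illustration, not the repository's launcher code.

```python
import argparse


# Stand-in consumer classes; the real package imports SimpleConsumer,
# GRPOConsumer, and PPOConsumer from their respective modules.
class SimpleConsumer: ...
class GRPOConsumer: ...
class PPOConsumer: ...


ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}


def get_consumer_cls(algo: str):
    # Fail loudly on an unknown algorithm name instead of silently defaulting.
    if algo not in ALGO_MAP:
        raise ValueError(f"Unknown algorithm {algo!r}, expected one of {sorted(ALGO_MAP)}")
    return ALGO_MAP[algo]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo", choices=sorted(ALGO_MAP), default="GRPO")
    args = parser.parse_args()
    print(f"Selected consumer: {get_consumer_cls(args.algo).__name__}")
```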
@@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
-from coati.trainer.utils import all_reduce_mean
 from coati.models import Critic, disable_dropout
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -154,9 +154,7 @@ class PPOConsumer(BaseConsumer):
         )
         value = value[:, -num_action - 1 : -1] * action_mask
 
-        r = self.reward_model(
-            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-        )
+        r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
         reward, kl = compute_reward_ppo(
             r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
         )
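`compute_reward_ppo` combines the verifier reward `r` with a KL penalty against the reference policy. Below is a minimal sketch of that recipe, assuming a per-token approximate KL, a scalar reward credited to the last response token, and response tokens left-aligned inside the action window; the shapes and names are illustrative, not the repository's exact implementation.

```python
import torch


def compute_reward_ppo_sketch(
    r: torch.Tensor,            # (B,)   scalar reward per sequence
    kl_coef: float,
    log_probs: torch.Tensor,    # (B, T) log-probs under the current policy
    log_probs_ref: torch.Tensor,# (B, T) log-probs under the frozen reference policy
    action_mask: torch.Tensor,  # (B, T) 1 for response tokens, 0 for padding
):
    # Per-token approximate KL(current || reference); masked so padding does not contribute.
    kl = (log_probs - log_probs_ref) * action_mask                    # (B, T)
    kl_mean = kl.sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)   # (B,)

    # Dense per-token reward: a KL penalty everywhere, plus the scalar reward
    # added on the last response token (assumes responses are left-aligned in the window).
    reward = -kl_coef * kl                                            # (B, T)
    last_idx = action_mask.sum(dim=-1).long().clamp(min=1) - 1        # (B,)
    reward[torch.arange(r.size(0)), last_idx] += r
    return reward, kl_mean
```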
@@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
                     data["input_ids"][i], skip_special_tokens=True
                 )
                 response_reward_for_logging = r[i]
-                print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                print(
+                    f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
+                )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
@@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mean = tensor / (mask_sum + 1e-8)
     return mean
 
+
 def compute_reward_ppo(
     r: Union[torch.Tensor, float],
     kl_coef: float,
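For reference, `masked_mean` averages only over unmasked positions. A minimal reimplementation consistent with the `mask_sum + 1e-8` line shown in the hunk above (intermediate variable names may differ from the real file):

```python
import torch


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Zero out masked positions, then divide by the number of unmasked entries.
    # The 1e-8 term (as in the hunk above) guards against an all-zero mask.
    summed = (tensor * mask).sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    return summed / (mask_sum + 1e-8)


# Example: mean over valid (unpadded) positions only.
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
m = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
print(masked_mean(x, m))  # tensor([1.5000, 4.0000])
```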
@@ -43,13 +43,7 @@ if __name__ == "__main__":
             )
         )
         generate_config.update(
-            dict(
-                max_length=768,
-                do_sample=True,
-                max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"]
-            )
+            dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
         )
     elif args.backend == "vllm":
         inference_model_config.update(
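The joined `dict(...)` above carries the generation settings for the transformers backend. A minimal sketch of how those kwargs are consumed by Hugging Face `generate`; the checkpoint name and prompt are illustrative only, and note that `stop_strings` requires passing the tokenizer to `generate`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative checkpoint, not taken from the diff
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

generate_config = dict(
    max_length=768,              # cap on prompt + response length
    do_sample=True,              # sample instead of greedy decoding
    max_new_tokens=None,         # defer to max_length rather than a separate new-token budget
    early_stopping=False,        # only meaningful for beam search; kept for parity with the diff
    stop_strings=["</answer>"],  # stop once the answer tag is closed
)

inputs = tokenizer("Solve: 2 + 2 = ?", return_tensors="pt")
# `stop_strings` needs the tokenizer at generation time to detect the stop text.
outputs = model.generate(**inputs, **generate_config, tokenizer=tokenizer)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```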