mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

branch: feat/ppo
parent c8e13a9403
commit 6e096362ef
@@ -13,8 +13,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch
@@ -90,13 +90,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)

         attention_mask = torch.cat((attention_mask, action_mask), dim=1)

         data = {
             "input_ids": out.sequences,
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
-            "response_idx": response_idx
+            "response_idx": response_idx,
         }
         return data

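The hunk above packages one rollout into the tensor dict that is shipped to the consumers. A minimal, self-contained sketch of the same packing step, assuming the prompt attention mask has one row per prompt while `action_mask`/`action_log_probs` have one row per generation (the wrapper function name is illustrative):

```python
import torch


def pack_rollout(out_sequences, prompt_attention_mask, action_mask, action_log_probs, response_idx):
    # One prompt can yield several generations: tile the prompt mask so its batch
    # dimension matches the per-generation action mask.
    repeats = action_mask.size(0) // prompt_attention_mask.size(0)
    attention_mask = prompt_attention_mask.repeat_interleave(repeats, dim=0)
    # Generated tokens sit after the prompt, so the response mask is appended
    # along the sequence dimension.
    attention_mask = torch.cat((attention_mask, action_mask), dim=1)
    return {
        "input_ids": out_sequences,
        "attention_mask": attention_mask,
        "action_log_probs": action_log_probs,
        "action_mask": action_mask,
        "response_idx": response_idx,
    }
```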
@@ -7,11 +7,7 @@ from .grpo_consumer import GRPOConsumer
 from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-    "PPO": PPOConsumer
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}


 def get_jsonl_size_fast(path: str) -> int:
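ALGO_MAP is a plain name-to-class registry. A hedged sketch of how such a registry is typically dispatched on a CLI flag; the stand-in classes and the helper name `get_consumer_cls` are illustrative, not necessarily the repo's exact wiring:

```python
# Stand-in classes so the sketch is self-contained; the real registry maps to
# SimpleConsumer, GRPOConsumer, and PPOConsumer.
class SimpleConsumer: ...
class GRPOConsumer: ...
class PPOConsumer: ...

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "PPO": PPOConsumer}


def get_consumer_cls(algo: str):
    # Fail loudly on unknown algorithm names instead of silently defaulting.
    if algo not in ALGO_MAP:
        raise NotImplementedError(f"Unknown algorithm {algo!r}; choose from {sorted(ALGO_MAP)}")
    return ALGO_MAP[algo]
```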
@@ -72,4 +72,4 @@ class ValueLoss(nn.Module):
             loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
         else:
             loss = torch.mean(torch.max(surr1, surr2))
-        return 0.5 * loss
+        return 0.5 * loss
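The tail above belongs to a PPO-style clipped value loss: `surr1`/`surr2` are squared errors of the clipped and unclipped value predictions, and the larger of the two is averaged over response tokens. A self-contained sketch under those assumptions; `clip_eps` and the argument names are not taken from the diff:

```python
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Mean over unmasked positions only, mirroring the helper used above.
    return (x * mask).sum(dim) / (mask.sum(dim) + 1e-8)


def clipped_value_loss(values, old_values, returns, action_mask=None, clip_eps: float = 0.2):
    # Clip the new value prediction to stay within clip_eps of the rollout-time value,
    # then take the pessimistic (larger) squared error, as in standard PPO.
    values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
    surr1 = (values_clipped - returns) ** 2
    surr2 = (values - returns) ** 2
    if action_mask is not None:
        loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
    else:
        loss = torch.mean(torch.max(surr1, surr2))
    return 0.5 * loss
```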
@@ -9,8 +9,8 @@ from coati.distributed.loss import PolicyLoss, ValueLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
-from coati.trainer.utils import all_reduce_mean
 from coati.models import Critic, disable_dropout
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.optimizer import HybridAdam
@@ -33,9 +33,9 @@ class PPOConsumer(BaseConsumer):
         plugin_config,
         microbatch_size=1,
         num_generations=1,
-        gamma:float=1.0,
-        lam:float=0.95,
-        kl_coef:float=0.05,
+        gamma: float = 1.0,
+        lam: float = 0.95,
+        kl_coef: float = 0.05,
         use_wandb=True,
     ):
         super().__init__(
@@ -55,7 +55,7 @@ class PPOConsumer(BaseConsumer):
         self.gamma = gamma
         self.lam = lam
         self.kl_coef = kl_coef

         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
@@ -63,14 +63,14 @@ class PPOConsumer(BaseConsumer):
         self.critic_model = Critic(path, **model_config)
         self.critic_model.model.gradient_checkpointing_enable()
         self.critic_model.train()

         # Disable dropout
         disable_dropout(self.policy_model)
         disable_dropout(self.critic_model)

         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
         self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)

         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -152,15 +152,13 @@ class PPOConsumer(BaseConsumer):
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )
-            value = value[:, -num_action -1: -1] * action_mask
+            value = value[:, -num_action - 1 : -1] * action_mask

-            r = self.reward_model(
-                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-            )
+            r = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
             reward, kl = compute_reward_ppo(
                 r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
             )

             # Calculate advantages
             # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46
             lastgaelam = 0
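The lines that begin at `lastgaelam = 0` set up the reversed-scan form of Generalized Advantage Estimation referenced by the TRL link. A hedged sketch of the full recurrence, assuming per-token `reward` and `value` tensors of shape `[batch, num_actions]` and the consumer's `gamma`/`lam` hyperparameters; only the stack-and-mask lines are taken verbatim from the diff:

```python
import torch


def gae_advantages(reward, value, action_mask, gamma: float = 1.0, lam: float = 0.95):
    # Reversed scan: A_t = delta_t + gamma * lam * A_{t+1},
    # with delta_t = r_t + gamma * V_{t+1} - V_t and V beyond the last token taken as 0.
    num_actions = reward.size(1)
    lastgaelam = 0
    advantage_reversed = []
    for t in reversed(range(num_actions)):
        nextvalue = value[:, t + 1] if t < num_actions - 1 else 0.0
        delta = reward[:, t] + gamma * nextvalue - value[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantage_reversed.append(lastgaelam)
    advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
    return advantage.detach()
```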
@@ -172,7 +170,7 @@ class PPOConsumer(BaseConsumer):
                 advantage_reversed.append(lastgaelam)
             advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
             advantage = advantage.detach()

             # KL divergence for logging
             per_token_kl = (
                 torch.exp(reference_action_log_probs - action_log_probs)
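`per_token_kl` is the start of the k3-style approximate KL used for logging: exp(Δ) − Δ − 1 with Δ = reference − current log-probs, which is non-negative token by token. A sketch of the full expression, with the masked averaging assumed to follow the `masked_mean` pattern used elsewhere in the diff:

```python
import torch


def per_token_kl_for_logging(action_log_probs, reference_action_log_probs, action_mask):
    # k3 estimator: exp(delta) - delta - 1 >= 0, with delta = ref - current.
    delta = reference_action_log_probs - action_log_probs
    per_token_kl = torch.exp(delta) - delta - 1
    # Average over response tokens only; the scalar is logged, not backpropagated.
    return (per_token_kl * action_mask).sum(dim=-1) / (action_mask.sum(dim=-1) + 1e-8)
```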
@@ -189,7 +187,7 @@ class PPOConsumer(BaseConsumer):
                 0,  # kl is already included in the advantage
                 action_mask,
             )

             # Critic Loss
             # Hack: use the current value to approximate the old value, should be old value mathematically
             critic_loss = self.critic_loss_fn(
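The "Hack" comment matters because value clipping in PPO needs the value prediction recorded at rollout time; feeding the current value in its place makes the clipped and unclipped branches identical, so the loss degenerates to a plain squared error. A small self-contained check of that claim (the `eps` name and the function itself are illustrative only):

```python
import torch


def value_clip_is_inert_when_old_equals_current(values, returns, eps: float = 0.2):
    # If old_values is literally the current values tensor, the clipped branch
    # collapses onto the unclipped one and max(surr1, surr2) is just the squared error.
    old_values = values
    values_clipped = old_values + (values - old_values).clamp(-eps, eps)
    assert torch.equal(values_clipped, values)
    surr1 = (values_clipped - returns) ** 2
    surr2 = (values - returns) ** 2
    return 0.5 * torch.max(surr1, surr2).mean()  # equals 0.5 * MSE in this degenerate case
```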
@@ -236,7 +234,9 @@ class PPOConsumer(BaseConsumer):
                     data["input_ids"][i], skip_special_tokens=True
                 )
                 response_reward_for_logging = r[i]
-                print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                print(
+                    f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}"
+                )
             self.wandb_run.log(
                 {
                     "train/loss": self.accum_loss.item() / self.accum_count,
@@ -94,11 +94,11 @@ class BaseProducer:
         print(
             f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
         )
-        for episode in range(self.num_episodes):
+        for episode in range(self.num_episodes):
             self.dataloader.sampler.set_epoch(episode)
             for i, batch in enumerate(self.dataloader):
                 if i >= num_valid_microbatches:
-                    break
+                    break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
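This is the producer's rollout loop: each episode reshuffles the prompt dataloader, each microbatch is rolled out, packed by `pre_send`, and then handed to the consumers. A compressed sketch of that control flow with the transport injected as a callable, since the broadcast step (ray_broadcast_tensor_dict) is outside this hunk:

```python
def producer_loop(dataloader, num_episodes, num_valid_microbatches, rollout, pre_send, send):
    # `rollout`, `pre_send`, and `send` are injected callables standing in for the
    # producer's methods and the Ray broadcast used in the real code.
    for episode in range(num_episodes):
        dataloader.sampler.set_epoch(episode)  # reshuffle prompts each episode
        for i, batch in enumerate(dataloader):
            if i >= num_valid_microbatches:
                break  # drop trailing microbatches the consumers would never pick up
            outputs = rollout(**batch)   # generate responses, log-probs, masks
            outputs = pre_send(outputs)  # prepare tensors for inter-process transfer
            send(outputs)                # hand the rollout to the consumers
```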
@@ -100,6 +100,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mean = tensor / (mask_sum + 1e-8)
     return mean

+
 def compute_reward_ppo(
     r: Union[torch.Tensor, float],
     kl_coef: float,
@@ -125,4 +126,4 @@ def compute_reward_ppo(
         assert action_mask[i].sum() > 0
         reward[i, : action_mask[i].sum()] += r_clip[i]
         reward[i, action_mask[i].sum() :] *= 0
-    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
+    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
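Only the tail of `compute_reward_ppo` is visible here: the clipped scalar reward is spread over each response's tokens, everything past the response is zeroed, and a clamped k3-style KL estimate is returned alongside. A hedged reconstruction under those assumptions; how `log_ratio` and the per-token KL penalty are formed upstream is not shown in the diff and is assumed below:

```python
import torch


def compute_reward_ppo_sketch(r, kl_coef, action_log_probs, reference_action_log_probs, action_mask, reward_clip=10.0):
    # r: [B] scalar reward per sample; log-probs and action_mask: [B, T].
    # The per-token KL penalty used as the base reward is an assumption; only the
    # loop and the return expression mirror the diff.
    log_ratio = action_log_probs - reference_action_log_probs
    reward = -kl_coef * log_ratio * action_mask
    r_clip = torch.clamp(r, -reward_clip, reward_clip)
    for i in range(reward.size(0)):
        n = int(action_mask[i].sum())
        assert n > 0
        reward[i, :n] += r_clip[i]   # add the clipped scalar reward over the response tokens
        reward[i, n:] *= 0           # zero out positions after the response ends
    # approximate KL for logging, with exp() suppressed for very large log-ratios
    kl = ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
    return reward, kl
```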
@@ -28,7 +28,7 @@ if __name__ == "__main__":
         top_p=0.8,
     )

-    if args.backend == "transformers":
+    if args.backend == "transformers":
         inference_model_config.update(
             dict(
                 use_flash_attention_2=True,
@@ -43,13 +43,7 @@ if __name__ == "__main__":
             )
         )
         generate_config.update(
-            dict(
-                max_length=768,
-                do_sample=True,
-                max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"]
-            )
+            dict(max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"])
         )
     elif args.backend == "vllm":
         inference_model_config.update(
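For the transformers backend the collapsed `generate_config` is eventually passed into `model.generate`. A hedged usage sketch: the checkpoint name is a placeholder, and note that `stop_strings` in recent transformers releases requires the tokenizer to be passed to `generate` so the stop text can be detected:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

generate_config = dict(
    max_length=768, do_sample=True, max_new_tokens=None, early_stopping=False, stop_strings=["</answer>"]
)

inputs = tokenizer("Solve: 2 + 2 = ? Put the result inside <answer></answer>.", return_tensors="pt")
# stop_strings is resolved against the tokenizer, so generate needs it explicitly.
out = model.generate(**inputs, **generate_config, tokenizer=tokenizer)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```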