
[chat]: add vf_coef argument for PPOTrainer (#3318)

Branch: pull/3536/head
Author: zhang-yi-chi (committed via GitHub, 2 years ago)
Commit: e6a132a449
Changed files:
  1. applications/Chat/coati/models/loss.py (2 changed lines)
  2. applications/Chat/coati/trainer/ppo.py (4 changed lines)

applications/Chat/coati/models/loss.py (2 changed lines)

@@ -65,7 +65,7 @@ class ValueLoss(nn.Module):
         surr2 = (values - reward)**2
         loss = torch.max(surr1, surr2)
         loss = loss.mean()
-        return loss
+        return 0.5 * loss


 class PPOPtxActorLoss(nn.Module):
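For context, the clipped value loss that this one-line change touches can be sketched as follows. This is a reconstruction based on the hunk above: the clipping of `values` around the old value estimates is an assumption, since those lines fall outside the diff.

    import torch
    import torch.nn as nn

    class ValueLoss(nn.Module):
        """PPO clipped value loss (sketch; only the final line is what the commit changes)."""

        def __init__(self, clip_eps: float = 0.4) -> None:
            super().__init__()
            self.clip_eps = clip_eps

        def forward(self, values: torch.Tensor, old_values: torch.Tensor,
                    reward: torch.Tensor, action_mask: torch.Tensor = None) -> torch.Tensor:
            # Assumption: values are clipped around the old estimates, mirroring
            # PPO's policy-ratio clipping; the diff only shows the lines below.
            values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
            surr1 = (values_clipped - reward)**2
            surr2 = (values - reward)**2
            loss = torch.max(surr1, surr2)
            loss = loss.mean()
            return 0.5 * loss  # the 0.5 factor is the change made here

The 0.5 turns the (unclipped) value objective into the conventional half mean squared error, whose gradient is simply (value - reward) before any further scaling.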

applications/Chat/coati/trainer/ppo.py (4 changed lines)

@@ -32,6 +32,7 @@ class PPOTrainer(Trainer):
         buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
         buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
         eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
+        vf_coef (float, defaults to 1.0): the coefficient of value loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
         experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
@@ -56,6 +57,7 @@ class PPOTrainer(Trainer):
                  buffer_limit: int = 0,
                  buffer_cpu_offload: bool = True,
                  eps_clip: float = 0.2,
+                 vf_coef: float = 1.0,
                  value_clip: float = 0.4,
                  experience_batch_size: int = 8,
                  max_epochs: int = 1,
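A minimal construction sketch showing where the new keyword fits. The positional arguments (strategy, models, optimizers) are assumed to precede these options as in the trainer's existing signature, which the diff does not show; the value 0.5 for vf_coef is purely illustrative.

    from coati.trainer import PPOTrainer  # import path assumed from the file layout

    def build_trainer(strategy, actor, critic, reward_model, initial_model,
                      actor_optim, critic_optim):
        # Only keyword arguments visible in the hunk above are spelled out;
        # everything else keeps its default.
        return PPOTrainer(strategy, actor, critic, reward_model, initial_model,
                          actor_optim, critic_optim,
                          eps_clip=0.2,
                          vf_coef=0.5,   # hypothetical value: down-weight the critic loss
                          value_clip=0.4,
                          experience_batch_size=8,
                          max_epochs=1)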
@@ -74,6 +76,7 @@ class PPOTrainer(Trainer):
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
+        self.vf_coef = vf_coef
         self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
         self.ptx_coef = ptx_coef
         self.actor_optim = actor_optim
@@ -112,6 +115,7 @@ class PPOTrainer(Trainer):
                                           experience.values,
                                           experience.reward,
                                           action_mask=experience.action_mask)
+        critic_loss = critic_loss * self.vf_coef
         self.strategy.backward(critic_loss, self.critic, self.critic_optim)
         self.strategy.optimizer_step(self.critic_optim)
         self.critic_optim.zero_grad()
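The effect of the added line is easiest to see in isolation: scaling the critic loss by vf_coef scales the gradients reaching the critic by exactly the same factor. A standalone sketch, not part of the commit, using an unclipped half-MSE value loss for brevity:

    import torch

    values = torch.tensor([0.3, 0.7], requires_grad=True)
    reward = torch.tensor([0.5, 0.5])

    vf_coef = 0.5
    critic_loss = 0.5 * ((values - reward) ** 2).mean()  # unclipped case, for brevity
    (critic_loss * vf_coef).backward()

    # values.grad equals vf_coef * (values - reward) / len(values): half the unscaled gradient
    print(values.grad)  # tensor([-0.0500,  0.0500])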
