Browse Source

[chat] fix compute_approx_kl (#4338)

pull/4359/head
Wenhao Chen 1 year ago committed by GitHub
parent
commit
75c5389037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      applications/Chat/coati/models/utils.py

2
applications/Chat/coati/models/utils.py

@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
action_mask: Mask for actions.
"""
log_ratio = log_probs - log_probs_base
log_ratio = log_probs_base - log_probs
approx_kl = (log_ratio.exp() - 1) - log_ratio
if action_mask is not None:
approx_kl = masked_mean(approx_kl, action_mask, dim=1)

Loading…
Cancel
Save