mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix compute_approx_kl (#4338)
parent
03654c0ce2
commit
75c5389037
|
@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
|
|||
action_mask: Mask for actions.
|
||||
"""
|
||||
|
||||
log_ratio = log_probs - log_probs_base
|
||||
log_ratio = log_probs_base - log_probs
|
||||
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
||||
if action_mask is not None:
|
||||
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
||||
|
|
Loading…
Reference in New Issue