mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix compute_approx_kl (#4338)
parent
03654c0ce2
commit
75c5389037
|
@@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
|
||||||
action_mask: Mask for actions.
|
action_mask: Mask for actions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
log_ratio = log_probs - log_probs_base
|
log_ratio = log_probs_base - log_probs
|
||||||
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
||||||
if action_mask is not None:
|
if action_mask is not None:
|
||||||
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
||||||
|
|
Loading…
Reference in New Issue