From b51bfec3573e2d217a8ab4f314cf891a53e18e19 Mon Sep 17 00:00:00 2001
From: wenjunyang <wendaleyang@gmail.com>
Date: Wed, 8 Mar 2023 15:18:02 +0800
Subject: [PATCH] [chatgpt] change critic input as state (#3042)

* fix Critic

* fix Critic

* fix Critic

* fix neglect of attention mask

* fix neglect of attention mask

* fix neglect of attention mask

* add return

---------

Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
---
 applications/ChatGPT/chatgpt/models/base/critic.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/applications/ChatGPT/chatgpt/models/base/critic.py b/applications/ChatGPT/chatgpt/models/base/critic.py
index 4bff5ee97..b12bddfcb 100644
--- a/applications/ChatGPT/chatgpt/models/base/critic.py
+++ b/applications/ChatGPT/chatgpt/models/base/critic.py
@@ -36,12 +36,15 @@ class Critic(LoRAModule):
         outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
 
-        values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
+        values = self.value_head(last_hidden_states).squeeze(-1)
 
         if action_mask is not None:
             num_actions = action_mask.size(1)
-            values = values[:, -num_actions:]
-            value = masked_mean(values, action_mask, dim=1)
+            prompt_mask = attention_mask[:, :-num_actions]
+            values = values[:, :-num_actions]
+            value = masked_mean(values, prompt_mask, dim=1)
             return value
+
+        values = values[:, :-1]
         value = values.mean(dim=1).squeeze(1)
         return value