From b51bfec3573e2d217a8ab4f314cf891a53e18e19 Mon Sep 17 00:00:00 2001
From: wenjunyang <wendaleyang@gmail.com>
Date: Wed, 8 Mar 2023 15:18:02 +0800
Subject: [PATCH] [chatgpt] change critic input as state (#3042)

* fix Critic

* fix Critic

* fix Critic

* fix neglect of attention mask

* fix neglect of attention mask

* fix neglect of attention mask

* add return

---------

Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
---
 applications/ChatGPT/chatgpt/models/base/critic.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/applications/ChatGPT/chatgpt/models/base/critic.py b/applications/ChatGPT/chatgpt/models/base/critic.py
index 4bff5ee97..b12bddfcb 100644
--- a/applications/ChatGPT/chatgpt/models/base/critic.py
+++ b/applications/ChatGPT/chatgpt/models/base/critic.py
@@ -36,12 +36,15 @@ class Critic(LoRAModule):
         outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
 
-        values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
+        values = self.value_head(last_hidden_states).squeeze(-1)
 
         if action_mask is not None:
             num_actions = action_mask.size(1)
-            values = values[:, -num_actions:]
-            value = masked_mean(values, action_mask, dim=1)
+            prompt_mask = attention_mask[:, :-num_actions]
+            values = values[:, :-num_actions]
+            value = masked_mean(values, prompt_mask, dim=1)
             return value
+
+        values = values[:, :-1]
         value = values.mean(dim=1).squeeze(1)
         return value