[chatgpt] change critic input as state (#3042)

* fix Critic

* fix Critic

* fix Critic

* fix neglect of attention mask

* fix neglect of attention mask

* fix neglect of attention mask

* add return

---------

Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
pull/3056/head
wenjunyang 2 years ago committed by GitHub
parent
commit
b51bfec357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
applications/ChatGPT/chatgpt/models/base/critic.py | 9

@@ -36,12 +36,15 @@ class Critic(LoRAModule):
         outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
-        values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
+        values = self.value_head(last_hidden_states).squeeze(-1)
         if action_mask is not None:
             num_actions = action_mask.size(1)
-            values = values[:, -num_actions:]
-            value = masked_mean(values, action_mask, dim=1)
+            prompt_mask = attention_mask[:, :-num_actions]
+            values = values[:, :-num_actions]
+            value = masked_mean(values, prompt_mask, dim=1)
             return value
+        values = values[:, :-1]
         value = values.mean(dim=1).squeeze(1)
         return value
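
To make the effect of this change concrete, here is a minimal sketch of the two pooling strategies in the action_mask branch. It works on a single sequence with plain Python lists instead of batched tensors. The helper masked_mean below is a stand-in written for illustration and is assumed to behave like the repository's function of the same name (a mean over positions where the mask is 1). The function names critic_value_old and critic_value_new are hypothetical.

```python
from typing import Sequence


def masked_mean(values: Sequence[float], mask: Sequence[int]) -> float:
    """Mean of values over positions where mask is 1 (illustrative stand-in)."""
    total = sum(v * m for v, m in zip(values, mask))
    # A small epsilon avoids division by zero when the mask is all zeros.
    return total / (sum(mask) + 1e-8)


def critic_value_old(values: Sequence[float], action_mask: Sequence[int]) -> float:
    """Before the change: drop the last position, then pool over the action tokens."""
    num_actions = len(action_mask)
    action_values = list(values)[:-1][-num_actions:]
    return masked_mean(action_values, action_mask)


def critic_value_new(values: Sequence[float],
                     attention_mask: Sequence[int],
                     num_actions: int) -> float:
    """After the change: pool over the prompt (state) tokens, skipping padding."""
    prompt_mask = list(attention_mask)[:-num_actions]
    prompt_values = list(values)[:-num_actions]
    return masked_mean(prompt_values, prompt_mask)


if __name__ == "__main__":
    # One left-padded sequence: 1 pad token, 2 prompt tokens, 3 action tokens.
    values = [9.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    attention_mask = [0, 1, 1, 1, 1, 1]
    print(critic_value_old(values, [1, 1, 1]))
    print(critic_value_new(values, attention_mask, 3))
```

In this sketch the new version ignores the padded position (value 9.0) because prompt_mask is sliced from attention_mask, which is what the "fix neglect of attention mask" entries in the commit message refer to. Note that the slice `[:-num_actions]` assumes num_actions is at least 1.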
