[chatgpt] change critic input as state (#3042)

* fix Critic

* fix Critic

* fix Critic

* fix neglect of attention mask

* fix neglect of attention mask

* fix neglect of attention mask

* add return

---------

Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
pull/3056/head
wenjunyang 2023-03-08 15:18:02 +08:00 committed by GitHub
parent 2ef855c798
commit b51bfec357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 6 additions and 3 deletions

View File

@@ -36,12 +36,15 @@ class Critic(LoRAModule):
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state']
values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
values = self.value_head(last_hidden_states).squeeze(-1)
if action_mask is not None:
num_actions = action_mask.size(1)
values = values[:, -num_actions:]
value = masked_mean(values, action_mask, dim=1)
prompt_mask = attention_mask[:, :-num_actions]
values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value
values = values[:, :-1]
value = values.mean(dim=1).squeeze(1)
return value