[chatgpt] change critic input as state (#3042)

* fix Critic

* fix Critic

* fix Critic

* fix neglect of attention mask

* fix neglect of attention mask

* fix neglect of attention mask

* add return

---------

Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
pull/3056/head
wenjunyang 2023-03-08 15:18:02 +08:00 committed by GitHub
parent 2ef855c798
commit b51bfec357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 3 deletions

View File

@ -36,12 +36,15 @@ class Critic(LoRAModule):
outputs = self.model(sequences, attention_mask=attention_mask) outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state'] last_hidden_states = outputs['last_hidden_state']
values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1] values = self.value_head(last_hidden_states).squeeze(-1)
if action_mask is not None: if action_mask is not None:
num_actions = action_mask.size(1) num_actions = action_mask.size(1)
values = values[:, -num_actions:] prompt_mask = attention_mask[:, :-num_actions]
value = masked_mean(values, action_mask, dim=1) values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value return value
values = values[:, :-1]
value = values.mean(dim=1).squeeze(1) value = values.mean(dim=1).squeeze(1)
return value return value