mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] change critic input as state (#3042)
* fix Critic * fix Critic * fix Critic * fix neglect of attention mask * fix neglect of attention mask * fix neglect of attention mask * add return --------- Co-authored-by: yangwenjun <yangwenjun@soyoung.com> Co-authored-by: yangwjd <yangwjd@chanjet.com>pull/3056/head
parent
2ef855c798
commit
b51bfec357
|
@ -36,12 +36,15 @@ class Critic(LoRAModule):
|
||||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||||
last_hidden_states = outputs['last_hidden_state']
|
last_hidden_states = outputs['last_hidden_state']
|
||||||
|
|
||||||
values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
|
values = self.value_head(last_hidden_states).squeeze(-1)
|
||||||
|
|
||||||
if action_mask is not None:
|
if action_mask is not None:
|
||||||
num_actions = action_mask.size(1)
|
num_actions = action_mask.size(1)
|
||||||
values = values[:, -num_actions:]
|
prompt_mask = attention_mask[:, :-num_actions]
|
||||||
value = masked_mean(values, action_mask, dim=1)
|
values = values[:, :-num_actions]
|
||||||
|
value = masked_mean(values, prompt_mask, dim=1)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
values = values[:, :-1]
|
||||||
value = values.mean(dim=1).squeeze(1)
|
value = values.mean(dim=1).squeeze(1)
|
||||||
return value
|
return value
|
||||||
|
|
Loading…
Reference in New Issue