diff --git a/applications/ChatGPT/chatgpt/models/base/actor.py b/applications/ChatGPT/chatgpt/models/base/actor.py
index e2841dc68..57db2bb11 100644
--- a/applications/ChatGPT/chatgpt/models/base/actor.py
+++ b/applications/ChatGPT/chatgpt/models/base/actor.py
@@ -37,7 +37,7 @@ class Actor(LoRAModule):
         if pad_token_id is not None:
             attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
         if not return_action_mask:
-            return sequences, attention_mask
+            return sequences, attention_mask, None
         input_len = input_ids.size(1)
         eos_token_id = kwargs.get('eos_token_id', None)
         if eos_token_id is None:
diff --git a/applications/ChatGPT/chatgpt/models/base/critic.py b/applications/ChatGPT/chatgpt/models/base/critic.py
index b12bddfcb..e68a743a7 100644
--- a/applications/ChatGPT/chatgpt/models/base/critic.py
+++ b/applications/ChatGPT/chatgpt/models/base/critic.py
@@ -18,15 +18,19 @@ class Critic(LoRAModule):
         lora_train_bias (str): LoRA bias training mode.
     """

-    def __init__(self,
-                 model: nn.Module,
-                 value_head: nn.Module,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        model: nn.Module,
+        value_head: nn.Module,
+        lora_rank: int = 0,
+        lora_train_bias: str = 'none',
+        use_action_mask: bool = False,
+    ) -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
         self.model = model
         self.value_head = value_head
+        self.use_action_mask = use_action_mask
         self.convert_to_lora()

     def forward(self,
@@ -38,7 +42,7 @@ class Critic(LoRAModule):

         values = self.value_head(last_hidden_states).squeeze(-1)

-        if action_mask is not None:
+        if action_mask is not None and self.use_action_mask:
             num_actions = action_mask.size(1)
             prompt_mask = attention_mask[:, :-num_actions]
             values = values[:, :-num_actions]
@@ -46,5 +50,5 @@ class Critic(LoRAModule):
             return value

         values = values[:, :-1]
-        value = values.mean(dim=1).squeeze(1)
+        value = values.mean(dim=1)
         return value
diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py
index 5a907309a..a32fb2e10 100644
--- a/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py
@@ -24,7 +24,8 @@ class BLOOMCritic(Critic):
                  config: Optional[BloomConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = BloomModel.from_pretrained(pretrained)
         elif config is not None:
@@ -34,4 +35,4 @@ class BLOOMCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py
index 897ddb4ae..01e824386 100644
--- a/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py
@@ -20,7 +20,8 @@ class GPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 **kwargs) -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:
@@ -30,4 +31,4 @@ class GPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head)
+        super().__init__(model, value_head, **kwargs)
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_critic.py b/applications/ChatGPT/chatgpt/models/opt/opt_critic.py
index 767cecb79..1f5ead758 100644
--- a/applications/ChatGPT/chatgpt/models/opt/opt_critic.py
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_critic.py
@@ -24,7 +24,8 @@ class OPTCritic(Critic):
                  config: Optional[OPTConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = OPTModel.from_pretrained(pretrained)
         elif config is not None:
@@ -34,4 +35,4 @@ class OPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
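
Note on usage: with this patch, use_action_mask can be passed to any of the concrete critics (BLOOMCritic, GPTCritic, OPTCritic) and is forwarded through **kwargs to the base Critic, which only takes the action-mask branch when the flag is set; otherwise it falls through to values[:, :-1].mean(dim=1), which the removed .squeeze(1) used to break (squeezing dim 1 of an already 1-D tensor raises an error). The sketch below is illustrative only: it assumes the chatgpt package under applications/ChatGPT is importable and that Critic.forward accepts (sequences, action_mask, attention_mask) keyword arguments as in the hunks above; the toy OPTConfig values and tensor shapes are made up.

import torch
from transformers import OPTConfig

from chatgpt.models.opt.opt_critic import OPTCritic

# Toy config so the sketch runs quickly; real training would pass pretrained=... instead.
config = OPTConfig(vocab_size=100,
                   hidden_size=32,
                   num_hidden_layers=2,
                   num_attention_heads=2,
                   ffn_dim=64,
                   max_position_embeddings=64)

# **kwargs forwarding carries the new flag down to the base Critic.
critic = OPTCritic(config=config, use_action_mask=True)

sequences = torch.randint(0, 100, (2, 10))        # (batch, prompt + response tokens)
attention_mask = torch.ones_like(sequences)
action_mask = torch.ones(2, 4, dtype=torch.bool)  # last 4 tokens are generated actions

# With use_action_mask=True, values are masked-averaged over the prompt positions;
# with the default False, the mean is taken over all positions but the last.
value = critic(sequences, action_mask=action_mask, attention_mask=attention_mask)
print(value.shape)  # torch.Size([2])

The actor-side change is analogous: Actor.generate now always returns a 3-tuple, filling the action-mask slot with None when return_action_mask is False, so callers can unpack three values unconditionally.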