mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt]add flag of action mask in critic(#3086)
parent
95a36eae63
commit
02ae80bf9c
|
@ -37,7 +37,7 @@ class Actor(LoRAModule):
|
|||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
|
|
|
@ -18,15 +18,19 @@ class Critic(LoRAModule):
|
|||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: nn.Module,
|
||||
value_head: nn.Module,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
value_head: nn.Module,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
use_action_mask: bool = False,
|
||||
) -> None:
|
||||
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.value_head = value_head
|
||||
self.use_action_mask = use_action_mask
|
||||
self.convert_to_lora()
|
||||
|
||||
def forward(self,
|
||||
|
@ -38,7 +42,7 @@ class Critic(LoRAModule):
|
|||
|
||||
values = self.value_head(last_hidden_states).squeeze(-1)
|
||||
|
||||
if action_mask is not None:
|
||||
if action_mask is not None and self.use_action_mask:
|
||||
num_actions = action_mask.size(1)
|
||||
prompt_mask = attention_mask[:, :-num_actions]
|
||||
values = values[:, :-num_actions]
|
||||
|
@ -46,5 +50,5 @@ class Critic(LoRAModule):
|
|||
return value
|
||||
|
||||
values = values[:, :-1]
|
||||
value = values.mean(dim=1).squeeze(1)
|
||||
value = values.mean(dim=1)
|
||||
return value
|
||||
|
|
|
@ -24,7 +24,8 @@ class BLOOMCritic(Critic):
|
|||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -34,4 +35,4 @@ class BLOOMCritic(Critic):
|
|||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
|
|
@ -20,7 +20,8 @@ class GPTCritic(Critic):
|
|||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
**kwargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -30,4 +31,4 @@ class GPTCritic(Critic):
|
|||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head)
|
||||
super().__init__(model, value_head, **kwargs)
|
||||
|
|
|
@ -24,7 +24,8 @@ class OPTCritic(Critic):
|
|||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
lora_train_bias: str = 'none',
|
||||
**kargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -34,4 +35,4 @@ class OPTCritic(Critic):
|
|||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue