"""
|
||
|
experience maker.
|
||
|
"""
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from coati.dataset.utils import find_first_occurrence_subsequence
|
||
|
from coati.models import Critic, RewardModel
|
||
|
from coati.models.generation import generate
|
||
|
from coati.models.utils import calc_action_log_probs, compute_reward
|
||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||
|
|
||
|
from colossalai.logging import get_dist_logger
|
||
|
|
||
|
from .base import Experience, ExperienceMaker
|
||
|
|
||
|
logger = get_dist_logger()
|
||
|
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
|
||
|
def is_rank_0() -> bool:
|
||
|
return not dist.is_initialized() or dist.get_rank() == 0
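

# NOTE: the helper above is handy for gating side effects (e.g. logging or saving)
# to the main process when running under torch.distributed.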


class NaiveExperienceMaker(ExperienceMaker):
    """
    Naive experience maker.
    """

    def __init__(
        self,
        actor: PreTrainedModel,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        kl_coef: float = 0.01,
        gamma: float = 1.0,
        lam: float = 0.95,
    ) -> None:
        super().__init__(actor, critic, reward_model, initial_model)
        self.tokenizer = tokenizer
        self.kl_coef = kl_coef
        self.gamma = gamma
        self.lam = lam

    @torch.no_grad()
    def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
        """
        Calculates the advantage values for each action based on the value and reward tensors.

        Args:
            value (torch.Tensor): Values predicted by the critic, of shape [B, num_actions].
            reward (torch.Tensor): Per-action rewards, of shape [B, num_actions].
            num_actions (int): Number of actions.

        Returns:
            torch.Tensor: The calculated advantages for each action, of shape [B, num_actions].
        """
        lastgaelam = 0
        advantages_reversed = []
        for t in reversed(range(num_actions)):
            nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0
            delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        return advantages
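
    # Illustrative only (the names below are assumptions, not part of this file): with
    # value and reward of shape [B, num_actions],
    #   advantages = maker.calculate_advantage(value, reward, num_actions)
    # returns a tensor of the same [B, num_actions] shape.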

    @torch.no_grad()
    def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
        """
        Generates an experience using the given input_ids and attention_mask.

        Args:
            input_ids (torch.Tensor): The input tensor containing the tokenized input sequence.
            attention_mask (torch.Tensor): The attention mask tensor indicating which tokens to attend to.
            **generate_kwargs: Additional keyword arguments for the generation process.

        Returns:
            Experience: The generated experience object.
        """
        self.actor.eval()
        self.critic.eval()
        self.initial_model.eval()
        self.reward_model.eval()
        pad_token_id = self.tokenizer.pad_token_id

        stop_token_ids = generate_kwargs.get("stop_token_ids", None)
        torch.manual_seed(41)  # for tp, guarantee the same input for reward model

        sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)

        # Pad to max length
        sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
        sequence_length = sequences.size(1)

        # Calculate auxiliary tensors
        attention_mask = None
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)

        input_len = input_ids.size(1)
        if stop_token_ids is None:
            # End the sequence with the eos token
            eos_token_id = self.tokenizer.eos_token_id
            if eos_token_id is None:
                action_mask = torch.ones_like(sequences, dtype=torch.bool)
            else:
                # Left padding may be applied, only mask actions
                action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
                action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
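                # The cumsum above is True up to (but excluding) the first eos token in the
                # response; the pad then prepends True for the prompt and shifts the mask
                # right by one position so the eos token itself stays unmasked.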
        else:
            # stop_token_ids are given, generation ends with stop_token_ids
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
            for i in range(sequences.size(0)):
                stop_index = find_first_occurrence_subsequence(
                    sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device)
                )
                if stop_index == -1:
                    # The sequence does not contain stop_token_ids; this should not happen.
                    logger.warning(
                        "Generated sequence does not contain stop_token_ids. Please check your chat template config"
                    )
                else:
                    # Keep the stop tokens
                    stop_index = input_len + stop_index
                    action_mask[i, stop_index + len(stop_token_ids) :] = False

        generation_end_index = action_mask.sum(dim=-1) - 1
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
        num_actions = action_mask.size(1)
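
        # At this point action_mask covers the last num_actions positions of each
        # sequence, aligned with the per-token log probs computed below.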
        actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"]
        action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)

        base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"]
        base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)

        # Convert to right padding for the reward model and the critic model
        input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
        attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
        for i in range(sequences.size(0)):
            sequence = sequences[i]
            bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
            eos_index = generation_end_index[i]
            sequence_to_pad = sequence[bos_index:eos_index]
            sequence_padded = F.pad(
                sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
            )
            input_ids_rm[i] = sequence_padded
            if sequence_length - sequence_to_pad.size(0) > 0:
                attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
            else:
                attention_mask_rm[i, :] = 1
        attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)

        r = self.reward_model(
            input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
            attention_mask=attention_mask_rm.to(device=sequences.device),
        )

        value = self.critic(
            input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
            attention_mask=attention_mask_rm.to(device=sequences.device),
        )
        reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
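        # compute_reward folds a per-token KL penalty against the initial model into the
        # reward (roughly reward = r - kl_coef * KL), which discourages the actor from
        # drifting too far from the reference policy.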
        value = value[:, -num_actions:] * action_mask
        advantages = self.calculate_advantage(value, reward, num_actions)

        advantages = advantages.detach()
        value = value.detach()
        r = r.detach()

        return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask)
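

# Minimal usage sketch (illustrative; the constructors and checkpoint names below are
# assumptions, not part of this file):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")            # hypothetical base model
#   actor = AutoModelForCausalLM.from_pretrained("gpt2")
#   initial_model = AutoModelForCausalLM.from_pretrained("gpt2")
#   critic = Critic("gpt2")                                      # constructor args assumed
#   reward_model = RewardModel("gpt2")                           # constructor args assumed
#
#   maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
#   batch = tokenizer(["Hello, world!"], return_tensors="pt", padding=True)
#   exp = maker.make_experience(batch["input_ids"], batch["attention_mask"],
#                               max_length=64, do_sample=True)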