You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/applications/ColossalChat/coati/experience_maker/naive.py

181 lines
7.4 KiB

"""
experience maker.
"""
import torch
import torch.nn.functional as F
from coati.dataset.utils import find_first_occurrence_subsequence
from coati.models import Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedModel, PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .base import Experience, ExperienceMaker
logger = get_dist_logger()
import torch.distributed as dist
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
class NaiveExperienceMaker(ExperienceMaker):
"""
Naive experience maker.
"""
def __init__(
self,
actor: PreTrainedModel,
critic: Critic,
reward_model: RewardModel,
initial_model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
kl_coef: float = 0.01,
gamma: float = 1.0,
lam: float = 0.95,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
self.gamma = gamma
self.lam = lam
@torch.no_grad()
def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
"""
Calculates the advantage values for each action based on the value and reward tensors.
Args:
value (torch.Tensor): Tensor containing the predicted values from critic.
reward (torch.Tensor): reward of the shape [B, len].
num_actions (int): Number of actions.
Returns:
torch.Tensor: Tensor containing the calculated advantages for each action.
"""
lastgaelam = 0
advantages_reversed = []
for t in reversed(range(num_actions)):
nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0
delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
return advantages
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
"""
Generates an experience using the given input_ids and attention_mask.
Args:
input_ids (torch.Tensor): The input tensor containing the tokenized input sequence.
attention_mask (torch.Tensor): The attention mask tensor indicating which tokens to attend to.
**generate_kwargs: Additional keyword arguments for the generation process.
Returns:
Experience: The generated experience object.
"""
self.actor.eval()
self.critic.eval()
self.initial_model.eval()
self.reward_model.eval()
pad_token_id = self.tokenizer.pad_token_id
stop_token_ids = generate_kwargs.get("stop_token_ids", None)
torch.manual_seed(41) # for tp, gurantee the same input for reward model
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
# Pad to max length
sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
sequence_length = sequences.size(1)
# Calculate auxiliary tensors
attention_mask = None
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
if stop_token_ids is None:
# End the sequence with eos token
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# Left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
else:
# stop_token_ids are given, generation ends with stop_token_ids
action_mask = torch.ones_like(sequences, dtype=torch.bool)
for i in range(sequences.size(0)):
stop_index = find_first_occurrence_subsequence(
sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device)
)
if stop_index == -1:
# Sequence does not contain stop_token_ids, this should never happen BTW
logger.warning(
"Generated sequence does not contain stop_token_ids. Please check your chat template config"
)
else:
# Keep stop tokens
stop_index = input_len + stop_index
action_mask[i, stop_index + len(stop_token_ids) :] = False
generation_end_index = (action_mask == True).sum(dim=-1) - 1
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
# Convert to right padding for the reward model and the critic model
input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
for i in range(sequences.size(0)):
sequence = sequences[i]
bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
eos_index = generation_end_index[i]
sequence_to_pad = sequence[bos_index:eos_index]
sequence_padded = F.pad(
sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
)
input_ids_rm[i] = sequence_padded
if sequence_length - sequence_to_pad.size(0) > 0:
attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
else:
attention_mask_rm[i, :] = 1
attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
r = self.reward_model(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
value = self.critic(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
value = value[:, -num_actions:] * action_mask
advantages = self.calculate_advantage(value, reward, num_actions)
advantages = advantages.detach()
value = value.detach()
r = r.detach()
return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask)