ColossalAI/applications/ColossalChat/coati/experience_maker/naive.py

304 lines
14 KiB
Python
Executable File

"""
experience maker.
"""
from typing import Any
import torch
import torch.nn.functional as F
from coati.dataset.utils import find_first_occurrence_subsequence
from coati.models import Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedModel, PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .base import Experience, ExperienceMaker
logger = get_dist_logger()
import torch.distributed as dist
def is_rank_0() -> bool:
    """
    Tell whether the current process should act as rank 0.

    Returns:
        bool: True when torch.distributed has not been initialized, or when
            it has been and this process reports rank 0; False otherwise.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    # No process group has been set up: treat the lone process as rank 0.
    return True
class NaiveExperienceMaker(ExperienceMaker):
    """
    Naive experience maker.

    Generates responses with the actor, scores them with the reward model and
    packs sequences, log-probabilities, rewards and advantages into an
    ``Experience``. Two modes are supported:

    * PPO (``use_grpo=False``): a critic is required; advantages come from
      ``calculate_advantage``.
    * GRPO (``use_grpo=True``): no critic; every prompt is sampled
      ``num_generation`` times and advantages are the rewards normalised
      within each group of samples.
    """

    def __init__(
        self,
        actor: PreTrainedModel,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        kl_coef: float = 0.01,
        gamma: float = 1.0,
        lam: float = 0.95,
        use_grpo: bool = False,
        num_generation: int = 8,
        inference_batch_size: int = None,
        logits_forward_batch_size: int = 2,
    ) -> None:
        """
        Args:
            actor (PreTrainedModel): Policy model used for generation.
            critic (Critic): Value model; must be None when ``use_grpo`` is True.
            reward_model (RewardModel): Model that scores generated sequences.
            initial_model (PreTrainedModel): Reference model for the KL term.
            tokenizer (PreTrainedTokenizer): Tokenizer providing pad/eos ids.
            kl_coef (float): Coefficient applied to the KL term.
            gamma (float): Discount factor used in ``calculate_advantage``.
            lam (float): Lambda factor used in ``calculate_advantage``.
            use_grpo (bool): Use GRPO instead of PPO.
            num_generation (int): Samples per prompt in GRPO mode.
            inference_batch_size (int): Number of (expanded) prompts generated
                per mini-batch; None means the whole batch at once.
            logits_forward_batch_size (int): Number of sequences per forward
                pass when recomputing actor / reference logits.
        """
        super().__init__(actor, critic, reward_model, initial_model)
        self.tokenizer = tokenizer
        self.kl_coef = kl_coef
        self.gamma = gamma
        self.lam = lam
        self.use_grpo = use_grpo
        self.num_generation = num_generation
        self.inference_batch_size = inference_batch_size
        self.logits_forward_batch_size = logits_forward_batch_size
        if not self.use_grpo:
            assert self.critic is not None, "Critic model is required for PPO training."
        else:
            assert self.critic is None, "Critic model is not required for GRPO training."
            assert self.num_generation > 1, "Number of generations should be greater than 1 for GRPO training."

    @torch.inference_mode()
    def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
        """
        Calculates the advantage values for each action based on the value and reward tensors.

        Walks the action positions backwards and accumulates
        ``delta_t = reward_t + gamma * value_{t+1} - value_t`` with decay
        ``gamma * lam`` (generalised advantage estimation). The value after
        the last action is taken to be 0.

        Args:
            value (torch.Tensor): Tensor containing the predicted values from critic.
            reward (torch.Tensor): reward of the shape [B, len].
            num_actions (int): Number of actions.

        Returns:
            torch.Tensor: Tensor containing the calculated advantages for each action.
        """
        lastgaelam = 0
        advantages_reversed = []
        for t in reversed(range(num_actions)):
            nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0
            delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        # Entries were collected from the last step to the first; flip them back.
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        return advantages

    @staticmethod
    def _normalize_stop_token_ids(stop_token_ids: Any) -> Any:
        """Normalise ``stop_token_ids`` to a list of token-id lists (None passes through)."""
        if stop_token_ids is None:
            return None
        if isinstance(stop_token_ids, int):
            return [[stop_token_ids]]
        first = stop_token_ids[0] if len(stop_token_ids) > 0 else None
        if isinstance(first, int):
            return [stop_token_ids]
        if isinstance(first, list):
            return stop_token_ids
        raise ValueError(
            f"stop_token_ids should be a list of list of integers, a list of integers or an integers. got {stop_token_ids}"
        )

    def _build_generation_mask(self, sequences: torch.Tensor, input_len: int, stop_token_ids: Any) -> torch.Tensor:
        """Return a bool mask over ``sequences`` that is True up to and including the end of generation."""
        if stop_token_ids is None:
            # End the sequence with eos token
            eos_token_id = self.tokenizer.eos_token_id
            if eos_token_id is None:
                return torch.ones_like(sequences, dtype=torch.bool)
            # Left padding may be applied, only mask action
            mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            # Shift right by one so the eos token itself is kept, and mark the prompt as True.
            return F.pad(mask, (1 + input_len, -1), value=True)

        # stop_token_ids are given, generation ends with stop_token_ids
        mask = torch.ones_like(sequences, dtype=torch.bool)
        for row, sequence in enumerate(sequences):
            positions = [
                find_first_occurrence_subsequence(sequence[input_len:], torch.tensor(stop).to(sequences.device))
                for stop in stop_token_ids
            ]
            found = [(pos, len(stop)) for pos, stop in zip(positions, stop_token_ids) if pos != -1]
            if not found:
                # Sequence does not contain stop_token_ids, this should never happen BTW
                logger.warning(
                    "Generated sequence does not contain stop_token_ids. Please check your chat template config"
                )
                print(self.tokenizer.decode(sequence, skip_special_tokens=True))
                continue
            # Earliest stop sequence wins; ties go to the one listed first.
            stop_index, stop_len = min(found, key=lambda item: item[0])
            # Keep stop tokens
            mask[row, input_len + stop_index + stop_len :] = False
        return mask

    @torch.no_grad()
    def make_experience(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, gt_answer: Any = None, **generate_kwargs
    ) -> Experience:
        """
        Generates an experience using the given input_ids and attention_mask.

        Args:
            input_ids (torch.Tensor): The input tensor containing the tokenized input sequence.
            attention_mask (torch.Tensor): The attention mask tensor indicating which tokens to attend to.
                NOTE: the passed value is not used; the mask is rebuilt from the
                generated sequences and the tokenizer's pad token id.
            gt_answer (Any): Optional per-prompt ground-truth answers forwarded to the reward model.
            **generate_kwargs: Additional keyword arguments for the generation process.
                ``max_length`` is required (sequences are padded to it);
                ``stop_token_ids`` may be an int, a list of ints or a list of lists of ints.

        Returns:
            Experience: The generated experience object.

        Raises:
            ValueError: If ``stop_token_ids`` has an unsupported form or the
                tokenizer has no pad token id.
        """
        self.actor.eval()
        if self.critic is not None:
            self.critic.eval()
        self.initial_model.eval()
        self.reward_model.eval()

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Attention masks and the right-padding conversion below are both derived from the pad id.
            raise ValueError("tokenizer.pad_token_id must be set to make experiences.")

        stop_token_ids = self._normalize_stop_token_ids(generate_kwargs.get("stop_token_ids", None))
        if stop_token_ids is not None:
            generate_kwargs["stop_token_ids"] = stop_token_ids

        torch.manual_seed(41)  # for tp, guarantee the same input for reward model

        if self.use_grpo and self.num_generation > 1:
            # Generate multiple responses for each prompt
            input_ids = input_ids.repeat_interleave(self.num_generation, dim=0)
            if gt_answer is not None:
                # Repeat each answer so it stays aligned with the repeated prompts.
                gt_answer = [answer for answer in gt_answer for _ in range(self.num_generation)]

        # Resolve the mini-batch size locally so one call does not change the
        # configuration seen by later calls.
        inference_batch_size = self.inference_batch_size
        if inference_batch_size is None:
            inference_batch_size = input_ids.size(0)

        batch_sequences = []
        batch_input_ids_rm = []
        batch_attention_mask_rm = []
        batch_attention_mask = []
        batch_r = []
        batch_action_log_probs = []
        batch_base_action_log_probs = []
        batch_action_mask = []
        num_actions = 0
        input_len = input_ids.size(1)

        for start in range(0, input_ids.size(0), inference_batch_size):
            # ``start`` is already an element offset, so the window is [start, start + size).
            end = start + inference_batch_size
            sequences = generate(self.actor, input_ids[start:end], self.tokenizer, **generate_kwargs)

            # pad to max_len, you don't want to get an OOM error after a thousands of steps
            sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
            sequence_length = sequences.size(1)

            # Calculate auxiliary tensors
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)

            action_mask = self._build_generation_mask(sequences, input_len, stop_token_ids)
            # Index of the last kept token in each row (prompt positions still count as True here).
            generation_end_index = action_mask.sum(dim=-1) - 1
            action_mask[:, :input_len] = False
            action_mask = action_mask[:, 1:]
            action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
            num_actions = action_mask.size(1)

            torch.cuda.empty_cache()
            with torch.inference_mode():
                actor_output = []
                base_model_output = []
                # Recompute logits a few sequences at a time to bound peak memory.
                for offset in range(0, sequences.size(0), self.logits_forward_batch_size):
                    chunk = slice(offset, offset + self.logits_forward_batch_size)
                    actor_output.append(
                        self.actor(
                            input_ids=sequences[chunk],
                            attention_mask=attention_mask[chunk],
                            use_cache=False,
                        )["logits"]
                    )
                    base_model_output.append(
                        self.initial_model(
                            input_ids=sequences[chunk],
                            attention_mask=attention_mask[chunk],
                            use_cache=False,
                        )["logits"]
                    )
                actor_output = torch.cat(actor_output, dim=0)
                base_model_output = torch.cat(base_model_output, dim=0)
                action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
                base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)

            # Convert to right padding for the reward model and the critic model
            input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
            attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
            response_start = []
            response_end = []
            for row in range(sequences.size(0)):
                sequence = sequences[row]
                bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
                eos_index = generation_end_index[row] + 1  # include the stop token
                sequence_to_pad = sequence[bos_index:eos_index]
                response_start.append(input_len - bos_index)
                response_end.append(eos_index - bos_index)
                pad_amount = sequence_length - sequence_to_pad.size(0)
                input_ids_rm[row] = F.pad(sequence_to_pad, (0, pad_amount), value=self.tokenizer.pad_token_id)
                if pad_amount > 0:
                    # Attend to the content plus the first padding position.
                    attention_mask_rm[row, : sequence_to_pad.size(0) + 1] = 1
                else:
                    attention_mask_rm[row, :] = 1
            attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)

            r = self.reward_model(
                input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
                attention_mask=attention_mask_rm.to(device=sequences.device),
                response_start=response_start,
                response_end=response_end,
                gt_answer=gt_answer[start:end] if gt_answer is not None else None,
            )

            batch_sequences.append(sequences)
            batch_input_ids_rm.append(input_ids_rm)
            batch_attention_mask_rm.append(attention_mask_rm)
            batch_attention_mask.append(attention_mask)
            batch_r.append(r)
            # Log-probs are parked on CPU between mini-batches.
            batch_action_log_probs.append(action_log_probs.cpu())
            batch_base_action_log_probs.append(base_action_log_probs.cpu())
            batch_action_mask.append(action_mask)

        sequences = torch.cat(batch_sequences, dim=0)
        input_ids_rm = torch.cat(batch_input_ids_rm, dim=0)
        attention_mask_rm = torch.cat(batch_attention_mask_rm, dim=0)
        attention_mask = torch.cat(batch_attention_mask, dim=0)
        r = torch.cat(batch_r, dim=0)
        action_log_probs = torch.cat(batch_action_log_probs, dim=0).to(sequences.device)
        base_action_log_probs = torch.cat(batch_base_action_log_probs, dim=0).to(sequences.device)
        action_mask = torch.cat(batch_action_mask, dim=0).to(sequences.device)

        if not self.use_grpo:
            value = self.critic(
                input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
                attention_mask=attention_mask_rm.to(device=sequences.device),
            )
            value = value[:, -num_actions:] * action_mask
            reward, kl = compute_reward(
                r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask
            )
            advantages = self.calculate_advantage(value, reward, num_actions)
            advantages = advantages.detach()
            value = value.detach()
        else:
            # GRPO advantage calculation
            # Mean KL penalty over action tokens; the count is clamped so a row
            # with no action tokens yields 0 instead of NaN.
            action_token_count = torch.sum(action_mask, dim=-1).clamp(min=1)
            kl = (
                torch.sum(-self.kl_coef * (action_log_probs - base_action_log_probs) * action_mask, dim=-1)
                / action_token_count
            )
            r = kl + r
            # Normalise rewards within each group of ``num_generation`` samples of the same prompt.
            grouped = r.view(-1, self.num_generation)
            mean_gr = grouped.mean(dim=1).repeat_interleave(self.num_generation, dim=0)
            std_gr = grouped.std(dim=1).repeat_interleave(self.num_generation, dim=0)
            advantages = (r - mean_gr) / (std_gr + 1e-4)
            value = r.detach()  # dummy value
        r = r.detach()

        return Experience(
            sequences.cpu(),
            action_log_probs.cpu(),
            value.cpu(),
            r.cpu(),
            kl.cpu(),
            advantages.cpu(),
            attention_mask.cpu(),
            action_mask.cpu(),
        )