Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YeAnbang 2025-02-18 09:43:36 +08:00 committed by GitHub
parent ce0ec40811
commit d20c8ffd97
39 changed files with 1995 additions and 277 deletions

View File

@ -61,5 +61,6 @@ jobs:
PRETRAINED_MODEL_PATH: ./models
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PROMPT_RLVR_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data

View File

@ -158,6 +158,7 @@ temp/
applications/ColossalChat/logs
applications/ColossalChat/models
applications/ColossalChat/sft_data
applications/ColossalChat/kto_data
applications/ColossalChat/prompt_data
applications/ColossalChat/preference_data
applications/ColossalChat/temp

View File

@ -141,7 +141,7 @@ def setup_conversation_template(
pass
except ValueError as e:
raise ValueError(e)
if not dist.is_initialized() or dist.get_rank() == 0:
if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.")

View File

@ -155,13 +155,14 @@ class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
"""
gt_answer = [ins.get("gt_answer", None) for ins in instances]
instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
ret = super().__call__(instances=instances)
input_ids = F.pad(
ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
)
attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
return {"input_ids": input_ids, "attention_mask": attention_mask}
return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer}
@dataclass

View File

@ -147,7 +147,6 @@ def tokenize_prompt(
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
@ -167,7 +166,6 @@ def tokenize_prompt(
if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
template.messages = template.messages[:-1]
# Prepare data
prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
@ -185,12 +183,21 @@ def tokenize_prompt(
)
# `inputs_decode` can be used to check whether the tokenization is correct.
return dict(
input_ids=tokenized,
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
if "gt_answer" in data_point:
return dict(
input_ids=tokenized,
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
gt_answer=data_point["gt_answer"],
)
else:
return dict(
input_ids=tokenized,
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):

View File

@ -27,6 +27,8 @@ class NaiveExperienceBuffer(ExperienceBuffer):
self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
self.rng_sequence = []
self.ptr = 0
@torch.no_grad()
def append(self, experience: Experience) -> None:
@ -40,6 +42,9 @@ class NaiveExperienceBuffer(ExperienceBuffer):
if samples_to_remove > 0:
logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
self.rng_sequence = [i for i in range(len(self.items))]
random.shuffle(self.rng_sequence)
self.ptr = 0
def clear(self) -> None:
self.items.clear()
@ -52,7 +57,10 @@ class NaiveExperienceBuffer(ExperienceBuffer):
Returns:
A batch of sampled experiences.
"""
items = random.sample(self.items, self.sample_batch_size)
items = []
for _ in range(self.sample_batch_size):
self.ptr = (self.ptr + 1) % len(self.items)
items.append(self.items[self.rng_sequence[self.ptr]])
experience = make_experience_batch(items)
if self.cpu_offload:
experience.to_device(self.target_device)

View File

@ -2,6 +2,8 @@
experience maker.
"""
from typing import Any
import torch
import torch.nn.functional as F
from coati.dataset.utils import find_first_occurrence_subsequence
@ -38,14 +40,27 @@ class NaiveExperienceMaker(ExperienceMaker):
kl_coef: float = 0.01,
gamma: float = 1.0,
lam: float = 0.95,
use_grpo: bool = False,
num_generation: int = 8,
inference_batch_size: int = None,
logits_forward_batch_size: int = 2,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
self.gamma = gamma
self.lam = lam
self.use_grpo = use_grpo
self.num_generation = num_generation
self.inference_batch_size = inference_batch_size
self.logits_forward_batch_size = logits_forward_batch_size
if not self.use_grpo:
assert self.critic is not None, "Critic model is required for PPO training."
else:
assert self.critic is None, "Critic model is not required for GRPO training."
assert self.num_generation > 1, "Number of generations should be greater than 1 for GRPO training."
@torch.no_grad()
@torch.inference_mode()
def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
"""
Calculates the advantage values for each action based on the value and reward tensors.
@ -69,7 +84,9 @@ class NaiveExperienceMaker(ExperienceMaker):
return advantages
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
def make_experience(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, gt_answer: Any = None, **generate_kwargs
) -> Experience:
"""
Generates an experience using the given input_ids and attention_mask.
@ -83,98 +100,204 @@ class NaiveExperienceMaker(ExperienceMaker):
"""
self.actor.eval()
self.critic.eval()
if self.critic:
self.critic.eval()
self.initial_model.eval()
self.reward_model.eval()
pad_token_id = self.tokenizer.pad_token_id
stop_token_ids = generate_kwargs.get("stop_token_ids", None)
if isinstance(stop_token_ids, int):
stop_token_ids = [[stop_token_ids]]
elif isinstance(stop_token_ids[0], int):
stop_token_ids = [stop_token_ids]
elif isinstance(stop_token_ids[0], list):
pass
else:
raise ValueError(
f"stop_token_ids should be a list of list of integers, a list of integers or an integers. got {stop_token_ids}"
)
generate_kwargs["stop_token_ids"] = stop_token_ids
torch.manual_seed(41) # for tp, guarantee the same input for reward model
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
if self.use_grpo and self.num_generation > 1:
# Generate multiple responses for each prompt
input_ids = input_ids.repeat_interleave(self.num_generation, dim=0)
gt_answer_tmp = []
for t in gt_answer:
gt_answer_tmp.extend([t] * self.num_generation)
gt_answer = gt_answer_tmp
if self.inference_batch_size is None:
self.inference_batch_size = input_ids.size(0)
# Pad to max length
sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
sequence_length = sequences.size(1)
batch_sequences = []
batch_input_ids_rm = []
batch_attention_mask_rm = []
batch_attention_mask = []
batch_r = []
batch_action_log_probs = []
batch_base_action_log_probs = []
batch_action_mask = []
num_actions = 0
# Calculate auxiliary tensors
attention_mask = None
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
if input_ids[s:e].size(0) == 0:
break
sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
# pad to max_len, you don't want to get an OOM error after thousands of steps
sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
input_len = input_ids.size(1)
if stop_token_ids is None:
# End the sequence with eos token
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# Left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
else:
# stop_token_ids are given, generation ends with stop_token_ids
action_mask = torch.ones_like(sequences, dtype=torch.bool)
for i in range(sequences.size(0)):
stop_index = find_first_occurrence_subsequence(
sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device)
)
if stop_index == -1:
# Sequence does not contain stop_token_ids, this should never happen BTW
logger.warning(
"Generated sequence does not contain stop_token_ids. Please check your chat template config"
)
# Pad to max length
sequence_length = sequences.size(1)
# Calculate auxiliary tensors
attention_mask = None
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
if stop_token_ids is None:
# End the sequence with eos token
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# Keep stop tokens
stop_index = input_len + stop_index
action_mask[i, stop_index + len(stop_token_ids) :] = False
generation_end_index = (action_mask == True).sum(dim=-1) - 1
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
# Convert to right padding for the reward model and the critic model
input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
for i in range(sequences.size(0)):
sequence = sequences[i]
bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
eos_index = generation_end_index[i]
sequence_to_pad = sequence[bos_index:eos_index]
sequence_padded = F.pad(
sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
)
input_ids_rm[i] = sequence_padded
if sequence_length - sequence_to_pad.size(0) > 0:
attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
# Left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
else:
attention_mask_rm[i, :] = 1
attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
# stop_token_ids are given, generation ends with stop_token_ids
action_mask = torch.ones_like(sequences, dtype=torch.bool)
for i in range(sequences.size(0)):
stop_token_pos = [
find_first_occurrence_subsequence(
sequences[i][input_len:], torch.tensor(stop_token_id).to(sequences.device)
)
for stop_token_id in stop_token_ids
]
stop_index = min([i for i in stop_token_pos if i != -1], default=-1)
stop_token_id = stop_token_ids[stop_token_pos.index(stop_index)]
if stop_index == -1:
# Sequence does not contain stop_token_ids, this should never happen BTW
logger.warning(
"Generated sequence does not contain stop_token_ids. Please check your chat template config"
)
print(self.tokenizer.decode(sequences[i], skip_special_tokens=True))
else:
# Keep stop tokens
stop_index = input_len + stop_index
action_mask[i, stop_index + len(stop_token_id) :] = False
r = self.reward_model(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
generation_end_index = (action_mask == True).sum(dim=-1) - 1
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
torch.cuda.empty_cache()
with torch.inference_mode():
actor_output = []
base_model_output = []
for i in range(0, sequences.size(0), self.logits_forward_batch_size):
actor_output.append(
self.actor(
input_ids=sequences[i : i + self.logits_forward_batch_size],
attention_mask=attention_mask[i : i + self.logits_forward_batch_size],
use_cache=False,
)["logits"]
)
base_model_output.append(
self.initial_model(
input_ids=sequences[i : i + self.logits_forward_batch_size],
attention_mask=attention_mask[i : i + self.logits_forward_batch_size],
use_cache=False,
)["logits"]
)
actor_output = torch.cat(actor_output, dim=0)
base_model_output = torch.cat(base_model_output, dim=0)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
value = value[:, -num_actions:] * action_mask
advantages = self.calculate_advantage(value, reward, num_actions)
# Convert to right padding for the reward model and the critic model
input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
response_start = []
response_end = []
attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
for i in range(sequences.size(0)):
sequence = sequences[i]
bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
eos_index = generation_end_index[i] + 1 # include the stop token
sequence_to_pad = sequence[bos_index:eos_index]
response_start.append(input_len - bos_index)
response_end.append(eos_index - bos_index)
sequence_padded = F.pad(
sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
)
input_ids_rm[i] = sequence_padded
if sequence_length - sequence_to_pad.size(0) > 0:
attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
else:
attention_mask_rm[i, :] = 1
attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
advantages = advantages.detach()
value = value.detach()
r = self.reward_model(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
response_start=response_start,
response_end=response_end,
gt_answer=gt_answer[s:e],
)
batch_sequences.append(sequences)
batch_input_ids_rm.append(input_ids_rm)
batch_attention_mask_rm.append(attention_mask_rm)
batch_attention_mask.append(attention_mask)
batch_r.append(r)
batch_action_log_probs.append(action_log_probs.cpu())
batch_base_action_log_probs.append(base_action_log_probs.cpu())
batch_action_mask.append(action_mask)
sequences = torch.cat(batch_sequences, dim=0)
input_ids_rm = torch.cat(batch_input_ids_rm, dim=0)
attention_mask_rm = torch.cat(batch_attention_mask_rm, dim=0)
attention_mask = torch.cat(batch_attention_mask, dim=0)
r = torch.cat(batch_r, dim=0)
action_log_probs = torch.cat(batch_action_log_probs, dim=0).to(sequences.device)
base_action_log_probs = torch.cat(batch_base_action_log_probs, dim=0).to(sequences.device)
action_mask = torch.cat(batch_action_mask, dim=0).to(sequences.device)
if not self.use_grpo:
value = self.critic(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
value = value[:, -num_actions:] * action_mask
reward, kl = compute_reward(
r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask
)
advantages = self.calculate_advantage(value, reward, num_actions)
advantages = advantages.detach()
value = value.detach()
else:
# GRPO advantage calculation
kl = torch.sum(
-self.kl_coef * (action_log_probs - base_action_log_probs) * action_mask, dim=-1
) / torch.sum(
action_mask, dim=-1
) # address numerical instability issue
r = kl + r
mean_gr = r.view(-1, self.num_generation).mean(dim=1)
std_gr = r.view(-1, self.num_generation).std(dim=1)
mean_gr = mean_gr.repeat_interleave(self.num_generation, dim=0)
std_gr = std_gr.repeat_interleave(self.num_generation, dim=0)
advantages = (r - mean_gr) / (std_gr + 1e-4)
value = r.detach() # dummy value
r = r.detach()
return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask)
return Experience(
sequences.cpu(),
action_log_probs.cpu(),
value.cpu(),
r.cpu(),
kl.cpu(),
advantages.cpu(),
attention_mask.cpu(),
action_mask.cpu(),
)

View File

@ -4,12 +4,14 @@ from .generation import generate, generate_streaming, prepare_inputs_fn, update_
from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .rlvr_reward_model import RLVRRewardModel
from .utils import disable_dropout
__all__ = [
"BaseModel",
"Critic",
"RewardModel",
"RLVRRewardModel",
"PolicyLoss",
"ValueLoss",
"LogSigLoss",

View File

@ -1,3 +1,4 @@
import copy
from typing import Any, Callable, List, Optional
import torch
@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
return model_kwargs
def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
model_kwargs["input_ids"] = input_ids
return model_kwargs
def _sample(
model: Any,
tokenizer: Any,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = True,
@ -137,8 +139,8 @@ def _sample(
if max_new_tokens is None:
max_new_tokens = max_length - context_length
if context_length + max_new_tokens > max_length or max_new_tokens == 0:
print("Exeeded length limitation")
return input_ids
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
past = None
@ -183,18 +185,14 @@ def _sample(
if stop_token_ids is not None:
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
tokens_to_check = input_ids[:, -len(stop_token_ids) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
for stop_token_id in stop_token_ids:
tokens_to_check = input_ids[:, -len(stop_token_id) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
)
# Stop when each sentence is finished if early_stopping=True
if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
if i == context_length + max_new_tokens - 1:
# Force to end with stop token ids
input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
torch.LongTensor(stop_token_ids).to(input_ids.device).long()
)
return input_ids
@ -237,8 +235,10 @@ def generate(
raise NotImplementedError
elif is_sample_gen_mode:
# Run sample
generation_kwargs = copy.deepcopy(model_kwargs)
res = _sample(
model,
tokenizer,
input_ids,
max_length,
early_stopping=early_stopping,
@ -249,8 +249,9 @@ def generate(
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
**generation_kwargs,
)
del generation_kwargs
return res
elif is_beam_gen_mode:
raise NotImplementedError
@ -350,11 +351,17 @@ def _sample_streaming(
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
if stop_token_ids is not None:
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
tokens_to_check = input_ids[:, -len(stop_token_ids) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
if isinstance(stop_token_ids[0], int):
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
else:
for stop_token_id in stop_token_ids:
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
)
# Stop when each sentence is finished if early_stopping=True
if (

View File

@ -25,7 +25,9 @@ class RewardModel(BaseModel):
self.value_head = nn.Linear(self.last_hidden_state_size, 1)
self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
outputs = self.model(input_ids, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]

View File

@ -0,0 +1,50 @@
"""
reward model
"""
from typing import Callable, List, Optional
import torch
class RLVRRewardModel:
"""
RLVRRewardModel class. Supports verifiable reward.
Args:
reward_fn_list (List): list of reward functions
**kwargs: all other kwargs as in reward functions
"""
def __init__(self, reward_fn_list: List[Callable], **kwargs) -> None:
self.reward_fn_list = reward_fn_list
self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
response_start: List = None,
response_end: List = None,
gt_answer: List = None,
) -> torch.Tensor:
# apply verifiable reward
bs = input_ids.size(0)
rewards = torch.zeros(bs, device=input_ids.device)
for i in range(bs):
for reward_fn in self.reward_fn_list:
rewards[i] += reward_fn(
input_ids[i],
attention_mask[i],
response_start=response_start[i],
response_end=response_end[i],
gt_answer=gt_answer[i],
**self.kwargs,
)
return rewards
def to(self, device):
return self
def eval(self):
return self
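To make the interface of the new `RLVRRewardModel` concrete, here is a minimal, self-contained sketch. The toy reward function and numbers are illustrative only; they simply mirror the keyword interface (`response_start`, `response_end`, `gt_answer`) used by the verifiable reward functions added later in this PR.

```python
import torch

from coati.models import RLVRRewardModel


def toy_reward_fn(input_ids, attention_mask, **kwargs):
    # Illustrative reward: 1.0 if the token ids in the response span sum to the
    # (toy) ground-truth value, otherwise 0.0. The real reward functions decode
    # the span with a tokenizer and check format plus answer correctness.
    s, e = kwargs["response_start"], kwargs["response_end"]
    correct = input_ids[s:e].sum().item() == kwargs["gt_answer"]
    return torch.tensor(1.0 if correct else 0.0)


reward_model = RLVRRewardModel(reward_fn_list=[toy_reward_fn])
input_ids = torch.tensor([[5, 7, 9, 0], [1, 2, 3, 0]])
attention_mask = torch.ones_like(input_ids)
rewards = reward_model(
    input_ids,
    attention_mask,
    response_start=[1, 0],
    response_end=[3, 3],
    gt_answer=[16, 99],
)
print(rewards)  # tensor([1., 0.])
```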

View File

@ -142,3 +142,17 @@ def disable_dropout(model: torch.nn.Module):
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0.0
def repad_to_left(tensor, tokenizer):
repadded_input_ids = []
max_non_padded_seq_len = 0
for i in range(tensor.size(0)):
non_pad_indices = (tensor[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
start, end = non_pad_indices.min(), non_pad_indices.max()
repadded_input_ids.append(tensor[i][start : end + 1])
max_non_padded_seq_len = max(max_non_padded_seq_len, repadded_input_ids[-1].size(0))
repadded_input_ids = [
F.pad(t, (max_non_padded_seq_len - t.size(0), 0), value=tokenizer.pad_token_id) for t in repadded_input_ids
]
return torch.stack(repadded_input_ids)

View File

@ -1,5 +1,6 @@
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .grpo import GRPOTrainer
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
@ -15,4 +16,5 @@ __all__ = [
"DPOTrainer",
"ORPOTrainer",
"KTOTrainer",
"GRPOTrainer",
]

View File

@ -96,6 +96,7 @@ class OLTrainer(ABC):
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
self.num_train_step = 0
@contextmanager
def _fit_ctx(self) -> None:
@ -212,5 +213,6 @@ class OLTrainer(ABC):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.data_buffer.clear()
if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0:
self._save_checkpoint(episode + 1)
if self.num_train_step > 0 and (self.num_train_step + 1) % (self.save_interval) == 0:
self._save_checkpoint(self.num_train_step + 1)

View File

@ -343,7 +343,7 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if (i + 1) % self.accumulation_steps == 0:
if (self.num_train_step + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
@ -358,26 +358,27 @@ class DPOTrainer(SLTrainer):
)
step_bar.update()
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
global_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()

View File

@ -0,0 +1,386 @@
"""
GRPO trainer
"""
import os
from typing import Dict, List, Optional, Union
import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models import RewardModel, RLVRRewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import OLTrainer
from .utils import AnnealingScheduler, CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
"""
Set default keyword arguments for generation based on the actor model.
Args:
actor (PreTrainedModel): The actor model.
Returns:
Dict: A dictionary containing the default keyword arguments for generation.
"""
unwrapped_model = actor.unwrap()
new_kwargs = {}
# use huggingface models method directly
if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
return new_kwargs
class GRPOTrainer(OLTrainer):
"""
Trainer for GRPO algorithm.
Args:
strategy (Booster): the strategy to use for training
actor (Actor): the actor model in the GRPO algorithm
reward_model (RewardModel or RLVRRewardModel): the reward model in the RLHF pipeline, used to score generated sentences
initial_model (Actor): the initial model in the RLHF pipeline, used to generate reference logits that limit the update of the actor
actor_optim (Optimizer): the optimizer to use for actor model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitation of buffer
buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(
self,
actor_booster: Booster,
actor: PreTrainedModel,
reward_model: Union[RewardModel, RLVRRewardModel],
initial_model: PreTrainedModel,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.2,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
use_tp: bool = False,
num_generation: int = 8,
inference_batch_size: int = None,
logits_forward_batch_size: int = None,
temperature_annealing_config: Optional[Dict] = None,
coordinator: DistCoordinator = None,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(actor_booster, GeminiPlugin):
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(actor_booster, None, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks)
self.generate_kwargs = _set_default_generate_kwargs(actor)
self.generate_kwargs.update(generate_kwargs)
self.actor = actor
self.actor_booster = actor_booster
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.experience_maker = NaiveExperienceMaker(
self.actor,
None,
reward_model,
initial_model,
self.tokenizer,
kl_coef,
use_grpo=True,
num_generation=num_generation,
inference_batch_size=inference_batch_size,
logits_forward_batch_size=logits_forward_batch_size,
)
if temperature_annealing_config:
# use annealing
self.temperature_annealing_scheduler = AnnealingScheduler(
temperature_annealing_config["start_temperature"],
temperature_annealing_config["end_temperature"],
temperature_annealing_config["annealing_warmup_steps"],
temperature_annealing_config["annealing_steps"],
)
else:
self.temperature_annealing_scheduler = None
self.train_batch_size = train_batch_size
self.actor_loss_fn = PolicyLoss(eps_clip)
self.vf_coef = vf_coef
self.ptx_loss_fn = GPTLMLoss()
self.ptx_coef = ptx_coef
self.actor_optim = actor_optim
self.save_interval = save_interval
self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.use_tp = use_tp
self.accumulative_meter = AccumulativeMeanMeter()
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: Optional[DataLoader] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-grpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "grpo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _setup_update_phrase_dataload(self):
"""
why not use distributed_dataloader?
if tp is used, input on each rank is the same and we use the same dataloader to feed same experience to all ranks
if tp is not used, input on each rank is different and we expect different experiences to be fed to each rank
"""
self.dataloader = DataLoader(
self.data_buffer,
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=self.data_buffer.collate_fn,
)
def _make_experience(self, collect_step: int) -> Experience:
"""
Make experience
"""
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
if self.temperature_annealing_scheduler:
self.generate_kwargs["temperature"] = self.temperature_annealing_scheduler.get_temperature()
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
gt_answer=prompts["gt_answer"],
**self.generate_kwargs,
)
def _training_step(self, experience: Experience):
"""
Args:
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
num_actions = experience.action_log_probs.size(1)
# policy loss
actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
"logits"
] # [batch size, prompt_length + response_length]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
action_log_probs,
experience.action_log_probs,
experience.advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
action_mask=experience.action_mask if self.apply_loss_mask else None,
)
# sequences that do not end properly are not counted in token cost
token_cost = torch.sum(
(experience.sequences[:, -num_actions:] != self.tokenizer.pad_token_id).to(torch.float), axis=-1
).to(actor_logits.device)
end_properly = experience.sequences[:, -1] == self.tokenizer.pad_token_id
mean_token_cost = torch.sum(token_cost * end_properly) / torch.sum(end_properly)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
ptx_loss = outputs.loss
ptx_loss = self.ptx_coef * ptx_loss
self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)
# sync
actor_loss_mean = all_reduce_mean(tensor=actor_loss)
max_ratio_mean = all_reduce_mean(tensor=max_ratio)
reward_mean = all_reduce_mean(tensor=experience.reward.mean())
advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
kl_mean = all_reduce_mean(tensor=experience.kl.mean())
mean_token_cost = all_reduce_mean(tensor=mean_token_cost)
if self.ptx_coef != 0:
ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)
self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
self.accumulative_meter.add("mean_token_cost", mean_token_cost.to(torch.float16).item())
self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
if self.ptx_coef != 0:
self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.actor_optim.step()
self.actor_optim.zero_grad()
self.actor_scheduler.step()
if self.temperature_annealing_scheduler:
self.temperature_annealing_scheduler.step_forward()
# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
for i in range(len(response_text)):
response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"
if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
# log output to wandb
my_table = wandb.Table(
columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
)
try:
self.wandb_run.log({"sample_response": my_table})
except OSError as e:
self.coordinator.print_on_master(e)
elif self.writer and is_rank_0():
for line in response_text:
self.coordinator.print_on_master(line)
if self.writer and is_rank_0():
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), global_step)
self.writer.add_scalar("train/skip_ratio", self.accumulative_meter.get("skip_ratio"), global_step)
self.writer.add_scalar("train/actor_loss", self.accumulative_meter.get("actor_loss"), global_step)
self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], global_step)
if self.ptx_coef != 0:
self.writer.add_scalar("train/ptx_loss", self.accumulative_meter.get("ptx_loss"), global_step)
self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), global_step)
self.writer.add_scalar("token_cost", self.accumulative_meter.get("mean_token_cost"), global_step)
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
self.accumulative_meter.reset()
def _learn(self, update_step: int):
"""
Perform the learning step of the GRPO algorithm.
Args:
update_step (int): The current update step.
Returns:
None
"""
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
def _save_checkpoint(self, num_train_step: int = 0):
"""
Save the actor checkpoints with running states.
Args:
num_train_step (int): The current num_train_step number.
Returns:
None
"""
self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
save_checkpoint(
save_dir=self.actor_save_dir,
booster=self.actor_booster,
model=self.actor,
optimizer=self.actor_optim,
lr_scheduler=self.actor_scheduler,
epoch=0,
step=num_train_step + 1,
batch_size=self.train_batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved actor checkpoint at episode {(num_train_step + 1)} at folder {self.actor_save_dir}"
)

View File

@ -217,25 +217,25 @@ class KTOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.accumulative_meter.reset()
@ -256,6 +256,7 @@ class KTOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
step_bar.close()

View File

@ -184,35 +184,35 @@ class ORPOTrainer(SLTrainer):
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
step_bar.update()
global_step = (self.num_train_step + 1) / self.accumulation_steps
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/log_odds_ratio",
self.accumulative_meter.get("log_odds_ratio"),
self.num_train_step,
global_step,
)
self.accumulative_meter.reset()
@ -233,6 +233,7 @@ class ORPOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
step_bar.close()

View File

@ -3,13 +3,13 @@ PPO trainer
"""
import os
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models import Critic, RewardModel
from coati.models import Critic, RewardModel, RLVRRewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
@ -84,7 +84,7 @@ class PPOTrainer(OLTrainer):
critic_booster: Booster,
actor: PreTrainedModel,
critic: Critic,
reward_model: RewardModel,
reward_model: Union[RewardModel, RLVRRewardModel],
initial_model: PreTrainedModel,
actor_optim: Optimizer,
critic_optim: Optimizer,
@ -210,6 +210,7 @@ class PPOTrainer(OLTrainer):
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
gt_answer=prompts["gt_answer"],
**self.generate_kwargs,
)

View File

@ -150,29 +150,29 @@ class RewardModelTrainer(SLTrainer):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
if (i + 1) % self.accumulation_steps == 0:
if (self.num_train_step + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
step_bar.update()
self.num_train_step += 1
# Logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/dist",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), global_step
)
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step)
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), global_step)
self.accumulative_meter.reset()
@ -193,6 +193,7 @@ class RewardModelTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
self.num_train_step += 1
step_bar.close()
def _eval(self, epoch):

View File

@ -143,15 +143,15 @@ class SFTTrainer(SLTrainer):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
# Gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
if (self.num_train_step + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
global_step = (self.num_train_step + 1) / self.accumulation_steps
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()

View File

@ -12,6 +12,27 @@ from torch.utils.data import DataLoader
from colossalai.booster import Plugin
class AnnealingScheduler:
def __init__(self, start, end, warmup_steps=100, annealing_step=2000):
self.start = start
self.end = end
self.warmup_steps = warmup_steps
self.step = 0
self.annealing_step = annealing_step
def get_temperature(self):
if self.step <= self.warmup_steps:
return self.start # Hold the start temperature during warm-up
elif self.step >= self.annealing_step:
return self.end
# Linear annealing
temp = self.start - (self.step / self.annealing_step) * (self.start - self.end)
return temp
def step_forward(self):
self.step += 1
class CycledDataLoader:
"""
A data loader that cycles through the data when it reaches the end.

View File

@ -0,0 +1,4 @@
from .competition import math_competition_reward_fn
from .gsm8k import gsm8k_reward_fn
__all__ = ["gsm8k_reward_fn", "math_competition_reward_fn"]

View File

@ -0,0 +1,26 @@
import torch
from .utils import extract_solution, validate_response_structure
def math_competition_reward_fn(input_ids, attention_mask, **kwargs):
# apply verifiable reward
# reward 10 points if the final answer is correct, reward 1 point if format is correct
gt_answer = kwargs["gt_answer"]
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not format_valid:
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
return reward

View File

@ -0,0 +1,31 @@
import torch
from .utils import extract_solution, validate_response_structure
def gsm8k_reward_fn(input_ids, attention_mask, **kwargs):
# apply verifiable reward
# reward 10 points if the final answer is correct, reward 1 point if format is correct
gt_answer = kwargs["gt_answer"]
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
is_valid = True
try:
int(final_answer.strip())
except Exception:
is_valid = False
format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not is_valid or not format_valid:
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
return reward

View File

@ -0,0 +1,76 @@
# Copyright Unakar
# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, Optional, Tuple
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
"""Performs comprehensive validation of response structure.
Args:
processed_str: Processed response string from the model
Returns:
Boolean indicating whether all formatting requirements are met
"""
validation_passed = True
# Check required tags
if tags is None:
tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
positions = {}
for tag_name, tag_info in tags.items():
tag_str = tag_info["text"]
expected_count = tag_info["num_occur"]
count = processed_str.count(tag_str)
positions[tag_name] = pos = processed_str.find(tag_str)
if count != expected_count:
validation_passed = False
# Verify tag order
if (
positions["think_start"] > positions["think_end"]
or positions["think_end"] > positions["answer_start"]
or positions["answer_start"] > positions["answer_end"]
):
validation_passed = False
if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
validation_passed = False
return validation_passed
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
"""Extracts the final answer from the model's response string.
Args:
solution_str: Raw response string from the language model
Returns:
Tuple containing (extracted_answer, processed_string)
"""
# Extract final answer using XML-style tags
answer_pattern = r"<answer>(.*?)</answer>"
matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
if not matches:
return None, solution_str
final_answer = matches[-1].group(1).strip()
return final_answer, solution_str

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
122753
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -0,0 +1,26 @@
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., <answer> 123 </answer>.\n",
"stop_ids": [
151643
],
"end_of_assistant": "<|endoftext|>",
"response_format_tags": {
"think_start": {
"text": "<think>",
"num_occur": 1
},
"think_end": {
"text": "</think>",
"num_occur": 1
},
"answer_start": {
"text": "<answer>",
"num_occur": 1
},
"answer_end": {
"text": "</answer>",
"num_occur": 1
}
}
}

View File

@ -27,6 +27,7 @@
- [Reward](#reward)
- [KL Divergence](#approximate-kl-divergence)
- [Note on PPO Training](#note-on-ppo-training)
- [GRPO Training and DeepSeek R1 reproduction](#grpo-training-and-deepseek-r1-reproduction)
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
@ -725,6 +726,75 @@ Answer: The causes of this problem are two-fold. Check your reward model, make s
#### Q4: Generation is garbage
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviates from its original state, which may lead to a decrease in language modeling capabilities. A way to fix this is to add a supervised loss during PPO: set ptx_coef to a non-zero value (between 0 and 1), which balances the PPO loss and the SFT loss.
## GRPO Training and DeepSeek R1 reproduction
We support GRPO (Group Relative Policy Optimization), the reinforcement learning algorithm used in the DeepSeek R1 paper. In this section, we walk through GRPO training with an example that aims to reproduce DeepSeek R1's results on mathematical problem solving.
### GRPO Model Selection
We select the base version of [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B). We also experimented with the instruct version [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct), but it fails to explore sufficiently diverse outputs. We recommend starting from a base model (without SFT) and applying a few SFT steps (see the [SFT section](#rlhf-training-stage1---supervised-instructs-tuning)) to correct the base model's output format before GRPO.
### Reinforcement Learning with Verifiable Reward
Both PPO and GRPO support reinforcement learning with verifiable reward (RLVR). In this experiment on mathematical problem solving, we define the reward function as follows; the format is considered correct if the response contains exactly one pair of <think></think> tags and exactly one pair of <answer></answer> tags, in that order (a minimal sketch of this reward follows the list).
- reward=0, if format is incorrect.
- reward=1, if format is correct but the answer doesn't match the ground truth answer exactly.
- reward=10, if the format is correct and the answer matches the ground truth answer exactly.
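The sketch below illustrates this reward scheme; the tag check and the exact-match comparison are simplified stand-ins for the utilities in `coati/utils/reward_score/`, and the function name is only illustrative:

```python
import re

def rlvr_math_reward(response: str, gt_answer: str) -> float:
    # Format is correct only with exactly one <think></think> pair and one <answer></answer> pair, in that order.
    thinks = re.findall(r"<think>(.*?)</think>", response, re.DOTALL)
    answers = re.findall(r"<answer>(.*?)</answer>", response, re.DOTALL)
    format_ok = (
        len(thinks) == 1
        and len(answers) == 1
        and response.find("</think>") < response.find("<answer>")
    )
    if not format_ok:
        return 0.0
    return 10.0 if answers[0].strip() == gt_answer.strip() else 1.0

print(rlvr_math_reward("<think>3 + 1 = 4</think><answer>4</answer>", "4"))  # 10.0
```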
### Step 1: Data Collection & Preparation
For GRPO training, you only need the prompt dataset (for RLVR, each sample should additionally provide a `gt_answer` field). Please follow the instructions in the [prompt dataset preparation](#rlhf-training-stage3---proximal-policy-optimization) section to prepare the prompt data for GRPO training. In our reproduction experiment, we use the [qwedsacf/competition_math dataset](https://huggingface.co/datasets/qwedsacf/competition_math), which is available on Hugging Face.
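For reference, an RLVR prompt sample is a regular prompt sample with an additional `gt_answer` field. A minimal example, mirroring the seed sample used in the repository's dummy-data generator, looks like this:

```python
# One RLVR prompt sample: the usual chat messages plus a verifiable ground-truth answer.
rlvr_prompt_sample = {
    "messages": [
        {
            "from": "user",
            "content": r"What is the degree of the polynomial $(4 +5x^3 +100 +2\pi x^4 + \sqrt{10}x^4 +9)$?",
        },
    ],
    "gt_answer": "4",
}
```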
### Step 2: Training
You can run [train_grpo.sh](./training_scripts/train_grpo.sh) to start GRPO training. The script shares most of its arguments with the PPO script (please refer to the [PPO training section](#step-3-training) for more details). Here are the arguments unique to GRPO.
```bash
--num_generations 8 \ # number of roll outs to collect for each prompt
--inference_batch_size 8 \ # batch size used during roll out
--logits_forward_batch_size 1 \ # batch size used to calculate logits for GRPO training
--initial_temperature 1.0 \ # initial temperature for temperature annealing (default: 1.0)
--final_temperature 0.9 \ # final temperature for temperature annealing (default: 0.9)
```
As GRPO requires collecting a group of responses for each prompt (usually at least 8), the effective batch size must satisfy the following constraints (a minimal sanity check is sketched after the two cases below):
- Without tensor parallelism,
```
experience buffer size
= num_process * num_collect_steps * experience_batch_size * num_generations
= train_batch_size * accumulation_steps * num_process
```
- With tensor parallelism,
```
num_tp_group = num_process / tp
experience buffer size
= num_tp_group * num_collect_steps * experience_batch_size * num_generations
= train_batch_size * accumulation_steps * num_tp_group
```
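A minimal sanity check for these constraints is sketched below; the numbers are illustrative, not recommended settings, and without tensor parallelism you can simply set tp=1 so that num_tp_group equals num_process:

```python
def check_grpo_batch_constraint(
    num_process=8, tp=1, num_collect_steps=4,
    experience_batch_size=2, num_generations=8,
    train_batch_size=16, accumulation_steps=4,
):
    num_tp_group = num_process // tp
    buffer_size = num_tp_group * num_collect_steps * experience_batch_size * num_generations
    consumed_per_episode = train_batch_size * accumulation_steps * num_tp_group
    assert buffer_size == consumed_per_episode, (
        f"experience buffer size {buffer_size} != consumed batch {consumed_per_episode}"
    )
    return buffer_size

print(check_grpo_batch_constraint())  # 512
```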
During rollout, we rebatch the prompts both before generation and before calculating logits to prevent out-of-memory errors. Please choose proper settings for "inference_batch_size" and "logits_forward_batch_size" based on your device.
### GRPO Result
#### Reward
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/reward.png">
</p>
#### Response Length
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost.png">
</p>
#### Response Length Distribution (After Training)
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost_eval.png">
</p>
#### Sample Response
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/res.png">
</p>
#### Note on Speed
Currently, our PPO and GRPO pipelines are still under development. Training speed is largely limited by rollout speed, as we use naive generation without any acceleration.
## Alternative Option For RLHF: Direct Preference Optimization

View File

@ -11,4 +11,4 @@ python prepare_dataset.py --type prompt \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024
--max_length 300

View File

@ -1,4 +1,4 @@
pandas>=1.4.1
sentencepiece
colossalai==0.4.0
colossalai==0.4.7
prompt_toolkit

View File

@ -0,0 +1,494 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
import torch.distributed as dist
from coati.dataset import (
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import LoraConfig, RewardModel, RLVRRewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import GRPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
actor = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
if args.rm_pretrain:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
if args.rm_pretrain:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if args.lora_config is not None:
actor = convert_to_lora_module(actor, lora_config=lora_config)
for name, module in actor.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
lora_manager.able_to_merge = False
# Disable dropout
disable_dropout(actor)
if args.grad_checkpoint:
actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
if os.path.exists(args.conversation_template_config):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
else:
raise ValueError("Conversation template config is not provided or incorrect")
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers don't allow setting pad_token manually, e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "left" # left padding for generation (online learning)
# configure generation config
actor.generation_config.update(
pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
)
# configure optimizer
coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
actor_optim = HybridAdam(
model_params=actor.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
if args.warmup_steps is None:
args.warmup_steps = int(0.025 * args.num_episodes)
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
actor_lr_scheduler = CosineAnnealingWarmupLR(
optimizer=actor_optim,
total_steps=args.num_episodes,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for debugging purpose
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
args.use_flash_attn = False
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
if args.rm_pretrain:
custom_plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
if args.plugin != "3d" and args.rm_pretrain:
custom_plugin = plugin
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = plugin.prepare_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
if len(args.ptx_dataset) > 0:
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = plugin.prepare_dataloader(
dataset=train_ptx_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
train_pretrain_dataloader = None
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
if args.rm_pretrain:
rm_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
dataloader=train_prompt_dataloader,
)
if args.rm_pretrain:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define a custom reward function, add your function under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and register it here by modifying the following lines:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
sampler_start_idx = 0
start_step = 0
if args.rm_checkpoint_path is not None:
if "modeling" in args.rm_checkpoint_path:
rm_booster.load_model(reward_model, args.rm_checkpoint_path)
else:
_, _, _ = load_checkpoint(
load_dir=args.rm_checkpoint_path,
booster=rm_booster,
model=reward_model,
optimizer=None,
lr_scheduler=None,
)
coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
actor_booster.load_model(actor, args.checkpoint_path)
ref_booster.load_model(ref_model, args.checkpoint_path)
coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
else:
_, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=actor_booster,
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
)
_, _, _ = load_checkpoint(load_dir=args.checkpoint_path, booster=ref_booster, model=ref_model)
assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
# configure trainer
trainer = GRPOTrainer(
actor_booster,
actor,
reward_model,
ref_model,
actor_optim,
actor_lr_scheduler,
tokenizer=tokenizer,
stop_token_ids=[stop_ids],
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,
buffer_limit=args.num_collect_steps * args.experience_batch_size * args.num_generations,
max_length=args.max_length,
use_cache=True,
do_sample=True,
apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
top_k=50,
use_tp=args.tp > 1,
num_generations=args.num_generations,
inference_batch_size=args.inference_batch_size,
logits_forward_batch_size=args.logits_forward_batch_size,
offload_inference_models="gemini" not in args.plugin,
coordinator=coordinator,
max_tokens_thinking=args.max_tokens_thinking if args.max_tokens_thinking else args.max_length - 100,
temperature_annealing_config={
"start_temperature": args.initial_temperature,
"end_temperature": args.final_temperature,
"annealing_warmup_steps": min(100, int(args.num_episodes / 6)),
"annealing_steps": min(600, int(args.num_episodes / 2)),
},
# Hack: some older models' default update_model_kwargs_fn/prepare_inputs_fn may not work due to a version conflict with transformers; you can overwrite them below if needed
# update_model_kwargs_fn=update_model_kwargs_fn,
# prepare_inputs_fn = None
)
trainer.fit(
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
prompt_dataloader=train_prompt_dataloader,
pretrain_dataloader=train_pretrain_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
lora_manager.able_to_merge = True
actor.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final actor model checkpoint")
actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_dataset", nargs="+", default=[])
parser.add_argument("--ptx_dataset", nargs="+", default=[])
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument(
"--conversation_template_config",
type=str,
default=None,
help="Path \
to save conversation template config files.",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--num_generations", type=int, default=8)
parser.add_argument("--inference_batch_size", type=int, default=None)
parser.add_argument("--save_interval", type=int, default=1000)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--logits_forward_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=4)
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--lr", type=float, default=1e-6)
parser.add_argument("--kl_coef", type=float, default=0.7)
parser.add_argument("--ptx_coef", type=float, default=0.0)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_tokens_thinking", type=int, default=2000)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--initial_temperature", type=float, default=1.0)
parser.add_argument("--final_temperature", type=float, default=0.9)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
train(args)

View File

@ -0,0 +1,86 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="PPO-RLVR"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file
LOGDIR=""
declare -a prompt_dataset=(
YOUR/PROMPT/DATA/DIR/arrow/part-00000
YOUR/PROMPT/DATA/DIR/arrow/part-00001
YOUR/PROMPT/DATA/DIR/arrow/part-00002
YOUR/PROMPT/DATA/DIR/arrow/part-00003
YOUR/PROMPT/DATA/DIR/arrow/part-00004
YOUR/PROMPT/DATA/DIR/arrow/part-00005
YOUR/PROMPT/DATA/DIR/arrow/part-00006
YOUR/PROMPT/DATA/DIR/arrow/part-00007
YOUR/PROMPT/DATA/DIR/arrow/part-00008
YOUR/PROMPT/DATA/DIR/arrow/part-00009
)
declare -a ptx_dataset=(
YOUR/SFT/DATA/DIR/arrow/part-00000
YOUR/SFT/DATA/DIR/arrow/part-00001
YOUR/SFT/DATA/DIR/arrow/part-00002
YOUR/SFT/DATA/DIR/arrow/part-00003
YOUR/SFT/DATA/DIR/arrow/part-00004
YOUR/SFT/DATA/DIR/arrow/part-00005
YOUR/SFT/DATA/DIR/arrow/part-00006
YOUR/SFT/DATA/DIR/arrow/part-00007
YOUR/SFT/DATA/DIR/arrow/part-00008
YOUR/SFT/DATA/DIR/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --num_nodes 1 --hostfile ./hostfile train_grpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--prompt_dataset ${prompt_dataset[@]} \
--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \
--ptx_coef 0.0 \
--plugin "zero2_cpu" \
--reward_functions math_competition_reward_fn \
--save_interval 250 \
--save_path $SAVE_DIR \
--num_episodes 100 \
--num_collect_steps 8 \
--num_update_steps 1 \
--experience_batch_size 1 \
--train_batch_size 4 \
--inference_batch_size 8 \
--logits_forward_batch_size 2 \
--accumulation_steps 4 \
--lr 1e-6 \
--mixed_precision "bf16" \
--grad_clip 0.1\
--weight_decay 0.01 \
--kl_coef 0.01 \
--warmup_steps 40 \
--max_length 2000 \
--max_seq_len 1700 \
--log_dir $LOGDIR \
--use_flash_attn \
--grad_checkpoint

View File

@ -13,9 +13,18 @@ from coati.dataset import (
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.models import (
Critic,
LoraConfig,
RewardModel,
RLVRRewardModel,
convert_to_lora_module,
disable_dropout,
lora_manager,
)
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
@ -61,28 +79,36 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
if not args.no_neural_reward_model:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
critic = Critic(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain)
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if not args.no_neural_reward_model:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
critic = Critic(args.rm_pretrain)
if args.lora_config is not None:
@ -112,6 +138,9 @@ def train(args):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
@ -245,7 +274,7 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
custom_policy=get_autopolicy(critic.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@ -284,7 +313,8 @@ def train(args):
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
rm_booster = Booster(plugin=custom_plugin)
if not args.no_neural_reward_model:
rm_booster = Booster(plugin=custom_plugin)
critic_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
@ -302,7 +332,28 @@ def train(args):
lr_scheduler=critic_lr_scheduler,
dataloader=train_prompt_dataloader,
)
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
if not args.no_neural_reward_model:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define a custom reward function, add your function under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and register it here by modifying the following lines:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
@ -481,9 +532,11 @@ if __name__ == "__main__":
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--critic_checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)

View File

@ -2,7 +2,7 @@ transformers==4.39.3
tqdm
datasets==2.14.7
loralib
colossalai>=0.4.0
colossalai>=0.4.7
torch>=2.1.0
langchain
tokenizers

View File

@ -20,6 +20,15 @@ prompt_seed = {
},
]
}
prompt_rlvr_seed = {
"messages": [
{
"from": "user",
"content": "What is the degree of the polynomial $(4 +5x^3 +100 +2\pi x^4 + \sqrt{10}x^4 +9)$?",
},
],
"gt_answer": "4",
}
preference_seed = {
"context": [
{"from": "user", "content": "What kind of noises did dinosaurs make?"},
@ -72,6 +81,8 @@ if __name__ == "__main__":
seed = sft_seed
elif args.data_type == "prompt":
seed = prompt_seed
elif args.data_type == "prompt_rlvr":
seed = prompt_rlvr_seed
elif args.data_type == "preference":
seed = preference_seed
elif args.data_type == "kto":

View File

@ -0,0 +1,16 @@
# run under /ColossalAI/applications/ColossalChat
export NCCL_SHM_DISABLE=1
export MAX_JOBS=1
export PRETRAINED_MODEL_PATH=./models
export SFT_DATASET=./sft_data
export PROMPT_DATASET=./prompt_data
export PROMPT_RLVR_DATASET=./prompt_data
export PREFERENCE_DATASET=./preference_data
export KTO_DATASET=./kto_data
mkdir models
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
# ./tests/test_data_preparation.sh
# ./tests/test_train.sh

View File

@ -24,7 +24,12 @@ if [ -z "$SFT_DATASET" ]; then
fi
if [ -z "$PROMPT_DATASET" ]; then
echo "Please set \$PROMPT_DATASET to the path to prompts."
echo "Please set \$PROMPT_DATASET to the path to prompts dataset."
exit 1
fi
if [ -z "$PROMPT_RLVR_DATASET" ]; then
echo "Please set \$PROMPT_RLVR_DATASET to the path to prompts dataset with gt_answer labels."
exit 1
fi
@ -69,6 +74,8 @@ get_data_input_dirs() {
echo "$SFT_DATASET"
elif [[ $data_type == "prompt" ]]; then
echo "$PROMPT_DATASET"
elif [[ $data_type == "prompt_rlvr" ]]; then
echo "$PROMPT_RLVR_DATASET"
elif [[ $data_type == "preference" ]]; then
echo "$PREFERENCE_DATASET"
elif [[ $data_type == "kto" ]]; then
@ -123,6 +130,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt) \
--data_type "prompt"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt_rlvr) \
--data_type "prompt_rlvr"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs kto) \
--data_type "kto"
@ -266,6 +277,52 @@ for model in ${MODELS[@]}; do
done
echo "[Test]: testing prepare_prompt_dataset.py (with verifiable reward)..."
# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)
# test prepare_prompt_dataset
for model in ${MODELS[@]}; do
data_type="prompt_rlvr"
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
echo "[Test]: Skipped $model-$data_type"
continue
fi
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
data_input_dirs=$(get_data_input_dirs $data_type)
tokenizer_dir=$(get_tokenizer_dirs $model)
conversation_template=$(get_conversation_template_config $model)
for i in $(seq $NUM_RETRY); do
rm -rf $cache_dir
rm -rf $jsonl_dir
rm -rf $arrow_dir
echo "[Test]: $model-$data_type, attempt $i"
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
--type prompt \
--data_input_dirs $data_input_dirs \
--conversation_template_config $conversation_template \
--tokenizer_dir $tokenizer_dir \
--data_cache_dir $cache_dir \
--data_jsonl_output_dir $jsonl_dir \
--data_arrow_output_dir $arrow_dir \
--max_length 400 \
--num_samples_per_datafile 100 \
--num_spliced_dataset_bins 1
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$data_type"
exit 1
fi
done
echo "[Test]: testing prepare_kto_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working

View File

@ -81,8 +81,242 @@ random_choice() {
echo ${arr[$idx]}
}
echo "[Test]: testing grpo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
REWARD_FLAG=('nn' 'vr')
for reward_type in ${REWARD_FLAG[@]}; do
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
rm_pretrain="--rm_pretrain $pretrain"
reward_fn=""
if [[ $reward_type == "vr" ]]; then
rm_pretrain=""
reward_fn="--reward_functions gsm8k_reward_fn"
fi
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
ebs='1'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='2'
ebs='1'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
if [[ $reward_type == "vr" ]]; then
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
else
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
fi
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \
--pretrain $pretrain \
$rm_pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--num_generations 2 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 200 \
--max_seq_len 10 \
$reward_fn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
exit 1
fi
done
done
done
done
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
REWARD_FLAG=('vr' 'nn')
for reward_type in ${REWARD_FLAG[@]}; do
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
reward_fn=""
no_nn=""
if [[ $reward_type == "vr" ]]; then
reward_fn="--reward_functions gsm8k_reward_fn"
no_nn="--no_neural_reward_model"
fi
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
ebs='2'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='2'
ebs='2'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
if [[ $reward_type == "vr" ]]; then
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
else
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
fi
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
$reward_fn \
$no_nn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
exit 1
fi
done
done
done
done
echo "[Test]: testing sft ..."
@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='4'
ebs='8'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='16'
ebs='32'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing DPO ..."
SKIPPED_TESTS=(
@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do
done
echo "[Test]: testing ORPO ..."
SKIPPED_TESTS=(
llama-3d-0
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
echo "[Test]: testing KTO ..."
SKIPPED_TESTS=(
llama-3d-0
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE