From d20c8ffd975ab6bd5ccaebcc13bb38b5f49cb279 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Tue, 18 Feb 2025 09:43:36 +0800 Subject: [PATCH] Add GRPO and Support RLVR for PPO (#6186) * add grpo, support rlvr * add grpo, support rlvr * tested deepseek r1 pipeline * add ci * verify grpo r1 * verify grpo r1 * update readme, remove unused code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove path * clean code * fix circular import * fix ci OOM * fix ci OOM * skip kto tp, fix qwen generation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/run_chatgpt_examples.yml | 1 + applications/ColossalChat/.gitignore | 1 + .../coati/dataset/conversation.py | 2 +- .../ColossalChat/coati/dataset/loader.py | 3 +- .../coati/dataset/tokenization_utils.py | 23 +- .../coati/experience_buffer/naive.py | 10 +- .../coati/experience_maker/naive.py | 287 +++++++--- .../ColossalChat/coati/models/__init__.py | 2 + .../ColossalChat/coati/models/generation.py | 39 +- .../ColossalChat/coati/models/reward_model.py | 4 +- .../coati/models/rlvr_reward_model.py | 50 ++ .../ColossalChat/coati/models/utils.py | 14 + .../ColossalChat/coati/trainer/__init__.py | 2 + .../ColossalChat/coati/trainer/base.py | 6 +- .../ColossalChat/coati/trainer/dpo.py | 15 +- .../ColossalChat/coati/trainer/grpo.py | 386 ++++++++++++++ .../ColossalChat/coati/trainer/kto.py | 15 +- .../ColossalChat/coati/trainer/orpo.py | 19 +- .../ColossalChat/coati/trainer/ppo.py | 7 +- applications/ColossalChat/coati/trainer/rm.py | 17 +- .../ColossalChat/coati/trainer/sft.py | 8 +- .../ColossalChat/coati/trainer/utils.py | 21 + .../coati/utils/reward_score/__init__.py | 4 + .../coati/utils/reward_score/competition.py | 26 + .../coati/utils/reward_score/gsm8k.py | 31 ++ .../coati/utils/reward_score/utils.py | 76 +++ .../conversation_template/MiniCPM-2b.json | 8 + .../Qwen_Qwen2.5-3B.json | 26 + applications/ColossalChat/examples/README.md | 70 +++ .../prepare_prompt_dataset.sh | 2 +- .../ColossalChat/examples/requirements.txt | 2 +- .../examples/training_scripts/train_grpo.py | 494 ++++++++++++++++++ .../examples/training_scripts/train_grpo.sh | 86 +++ .../examples/training_scripts/train_ppo.py | 77 ++- applications/ColossalChat/requirements.txt | 2 +- .../generate_dummy_datasets_for_testing.py | 11 + .../ColossalChat/tests/prepare_test_env.sh | 16 + .../tests/test_data_preparation.sh | 59 ++- applications/ColossalChat/tests/test_train.sh | 350 +++++++++---- 39 files changed, 1995 insertions(+), 277 deletions(-) create mode 100644 applications/ColossalChat/coati/models/rlvr_reward_model.py create mode 100755 applications/ColossalChat/coati/trainer/grpo.py create mode 100644 applications/ColossalChat/coati/utils/reward_score/__init__.py create mode 100644 applications/ColossalChat/coati/utils/reward_score/competition.py create mode 100644 applications/ColossalChat/coati/utils/reward_score/gsm8k.py create mode 100644 applications/ColossalChat/coati/utils/reward_score/utils.py create mode 100644 applications/ColossalChat/conversation_template/MiniCPM-2b.json create mode 100644 applications/ColossalChat/conversation_template/Qwen_Qwen2.5-3B.json create mode 100755 applications/ColossalChat/examples/training_scripts/train_grpo.py create mode 100755 applications/ColossalChat/examples/training_scripts/train_grpo.sh create mode 100755 applications/ColossalChat/tests/prepare_test_env.sh diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index 262def229..7a70a16b9 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -61,5 +61,6 @@ jobs: PRETRAINED_MODEL_PATH: ./models SFT_DATASET: ./sft_data PROMPT_DATASET: ./prompt_data + PROMPT_RLVR_DATASET: ./prompt_data PREFERENCE_DATASET: ./preference_data KTO_DATASET: ./kto_data diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore index 5a4bb905f..bc8517372 100755 --- a/applications/ColossalChat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -158,6 +158,7 @@ temp/ applications/ColossalChat/logs applications/ColossalChat/models applications/ColossalChat/sft_data +applications/ColossalChat/kto_data applications/ColossalChat/prompt_data applications/ColossalChat/preference_data applications/ColossalChat/temp diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py index a77c220d3..21ab6fa1f 100755 --- a/applications/ColossalChat/coati/dataset/conversation.py +++ b/applications/ColossalChat/coati/dataset/conversation.py @@ -141,7 +141,7 @@ def setup_conversation_template( pass except ValueError as e: raise ValueError(e) - if not dist.is_initialized() or dist.get_rank() == 0: + if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0): os.makedirs(os.path.dirname(save_path), exist_ok=True) with open(save_path, "w", encoding="utf8") as f: logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.") diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index b92cd76ad..cdbadab64 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -155,13 +155,14 @@ class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset): `input_ids`: `torch.Tensor` of shape (bsz, max_len); `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); """ + gt_answer = [ins.get("gt_answer", None) for ins in instances] instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances] ret = super().__call__(instances=instances) input_ids = F.pad( ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id ) attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False) - return {"input_ids": input_ids, "attention_mask": attention_mask} + return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer} @dataclass diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 020432b9e..893090edf 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -147,7 +147,6 @@ def tokenize_prompt( ignore_index: the ignore index when calculate loss during training max_length: the maximum context length """ - messages = data_point["messages"] template = deepcopy(conversation_template) template.messages = [] @@ -167,7 +166,6 @@ def tokenize_prompt( if len(template.messages) % 2 != 1: # exclude the answer if provided. keep only the prompt template.messages = template.messages[:-1] - # Prepare data prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True) tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] @@ -185,12 +183,21 @@ def tokenize_prompt( ) # `inputs_decode` can be used to check whether the tokenization method is true. - return dict( - input_ids=tokenized, - inputs_decode=prompt, - seq_length=len(tokenized), - seq_category=data_point["category"] if "category" in data_point else "None", - ) + if "gt_answer" in data_point: + return dict( + input_ids=tokenized, + inputs_decode=prompt, + seq_length=len(tokenized), + seq_category=data_point["category"] if "category" in data_point else "None", + gt_answer=data_point["gt_answer"], + ) + else: + return dict( + input_ids=tokenized, + inputs_decode=prompt, + seq_length=len(tokenized), + seq_category=data_point["category"] if "category" in data_point else "None", + ) def apply_rlhf_data_format(template: Conversation, tokenizer: Any): diff --git a/applications/ColossalChat/coati/experience_buffer/naive.py b/applications/ColossalChat/coati/experience_buffer/naive.py index b912df268..7194054a3 100755 --- a/applications/ColossalChat/coati/experience_buffer/naive.py +++ b/applications/ColossalChat/coati/experience_buffer/naive.py @@ -27,6 +27,8 @@ class NaiveExperienceBuffer(ExperienceBuffer): self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}") # TODO(ver217): add prefetch self.items: List[BufferItem] = [] + self.rng_sequence = [] + self.ptr = 0 @torch.no_grad() def append(self, experience: Experience) -> None: @@ -40,6 +42,9 @@ class NaiveExperienceBuffer(ExperienceBuffer): if samples_to_remove > 0: logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.") self.items = self.items[samples_to_remove:] + self.rng_sequence = [i for i in range(len(self.items))] + random.shuffle(self.rng_sequence) + self.ptr = 0 def clear(self) -> None: self.items.clear() @@ -52,7 +57,10 @@ class NaiveExperienceBuffer(ExperienceBuffer): Returns: A batch of sampled experiences. """ - items = random.sample(self.items, self.sample_batch_size) + items = [] + for _ in range(self.sample_batch_size): + self.ptr = (self.ptr + 1) % len(self.items) + items.append(self.items[self.rng_sequence[self.ptr]]) experience = make_experience_batch(items) if self.cpu_offload: experience.to_device(self.target_device) diff --git a/applications/ColossalChat/coati/experience_maker/naive.py b/applications/ColossalChat/coati/experience_maker/naive.py index 945bb9557..c7ad4f316 100755 --- a/applications/ColossalChat/coati/experience_maker/naive.py +++ b/applications/ColossalChat/coati/experience_maker/naive.py @@ -2,6 +2,8 @@ experience maker. """ +from typing import Any + import torch import torch.nn.functional as F from coati.dataset.utils import find_first_occurrence_subsequence @@ -38,14 +40,27 @@ class NaiveExperienceMaker(ExperienceMaker): kl_coef: float = 0.01, gamma: float = 1.0, lam: float = 0.95, + use_grpo: bool = False, + num_generation: int = 8, + inference_batch_size: int = None, + logits_forward_batch_size: int = 2, ) -> None: super().__init__(actor, critic, reward_model, initial_model) self.tokenizer = tokenizer self.kl_coef = kl_coef self.gamma = gamma self.lam = lam + self.use_grpo = use_grpo + self.num_generation = num_generation + self.inference_batch_size = inference_batch_size + self.logits_forward_batch_size = logits_forward_batch_size + if not self.use_grpo: + assert self.critic is not None, "Critic model is required for PPO training." + else: + assert self.critic is None, "Critic model is not required for GRPO training." + assert self.num_generation > 1, "Number of generations should be greater than 1 for GRPO training." - @torch.no_grad() + @torch.inference_mode() def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor: """ Calculates the advantage values for each action based on the value and reward tensors. @@ -69,7 +84,9 @@ class NaiveExperienceMaker(ExperienceMaker): return advantages @torch.no_grad() - def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience: + def make_experience( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, gt_answer: Any = None, **generate_kwargs + ) -> Experience: """ Generates an experience using the given input_ids and attention_mask. @@ -83,98 +100,204 @@ class NaiveExperienceMaker(ExperienceMaker): """ self.actor.eval() - self.critic.eval() + if self.critic: + self.critic.eval() self.initial_model.eval() self.reward_model.eval() pad_token_id = self.tokenizer.pad_token_id - stop_token_ids = generate_kwargs.get("stop_token_ids", None) + if isinstance(stop_token_ids, int): + stop_token_ids = [[stop_token_ids]] + elif isinstance(stop_token_ids[0], int): + stop_token_ids = [stop_token_ids] + elif isinstance(stop_token_ids[0], list): + pass + else: + raise ValueError( + f"stop_token_ids should be a list of list of integers, a list of integers or an integers. got {stop_token_ids}" + ) + generate_kwargs["stop_token_ids"] = stop_token_ids torch.manual_seed(41) # for tp, gurantee the same input for reward model - sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs) + if self.use_grpo and self.num_generation > 1: + # Generate multiple responses for each prompt + input_ids = input_ids.repeat_interleave(self.num_generation, dim=0) + gt_answer_tmp = [] + for t in gt_answer: + gt_answer_tmp.extend([t] * self.num_generation) + gt_answer = gt_answer_tmp + if self.inference_batch_size is None: + self.inference_batch_size = input_ids.size(0) - # Pad to max length - sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id) - sequence_length = sequences.size(1) + batch_sequences = [] + batch_input_ids_rm = [] + batch_attention_mask_rm = [] + batch_attention_mask = [] + batch_r = [] + batch_action_log_probs = [] + batch_base_action_log_probs = [] + batch_action_mask = [] + num_actions = 0 - # Calculate auxiliary tensors - attention_mask = None - if pad_token_id is not None: - attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size): + s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size + if input_ids[s:e].size(0) == 0: + break + sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs) + # pad to max_len, you don't want to get an OOM error after a thousands of steps + sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id) - input_len = input_ids.size(1) - if stop_token_ids is None: - # End the sequence with eos token - eos_token_id = self.tokenizer.eos_token_id - if eos_token_id is None: - action_mask = torch.ones_like(sequences, dtype=torch.bool) - else: - # Left padding may be applied, only mask action - action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input - else: - # stop_token_ids are given, generation ends with stop_token_ids - action_mask = torch.ones_like(sequences, dtype=torch.bool) - for i in range(sequences.size(0)): - stop_index = find_first_occurrence_subsequence( - sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device) - ) - if stop_index == -1: - # Sequence does not contain stop_token_ids, this should never happen BTW - logger.warning( - "Generated sequence does not contain stop_token_ids. Please check your chat template config" - ) + # Pad to max length + sequence_length = sequences.size(1) + + # Calculate auxiliary tensors + attention_mask = None + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + + input_len = input_ids.size(1) + if stop_token_ids is None: + # End the sequence with eos token + eos_token_id = self.tokenizer.eos_token_id + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) else: - # Keep stop tokens - stop_index = input_len + stop_index - action_mask[i, stop_index + len(stop_token_ids) :] = False - - generation_end_index = (action_mask == True).sum(dim=-1) - 1 - action_mask[:, :input_len] = False - action_mask = action_mask[:, 1:] - action_mask = action_mask[:, -(sequences.size(1) - input_len) :] - num_actions = action_mask.size(1) - - actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"] - action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) - - base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"] - - base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) - - # Convert to right padding for the reward model and the critic model - input_ids_rm = torch.zeros_like(sequences, device=sequences.device) - attention_mask_rm = torch.zeros_like(sequences, device=sequences.device) - for i in range(sequences.size(0)): - sequence = sequences[i] - bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0] - eos_index = generation_end_index[i] - sequence_to_pad = sequence[bos_index:eos_index] - sequence_padded = F.pad( - sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id - ) - input_ids_rm[i] = sequence_padded - if sequence_length - sequence_to_pad.size(0) > 0: - attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1 + # Left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input else: - attention_mask_rm[i, :] = 1 - attention_mask_rm = attention_mask_rm.to(dtype=torch.bool) + # stop_token_ids are given, generation ends with stop_token_ids + action_mask = torch.ones_like(sequences, dtype=torch.bool) + for i in range(sequences.size(0)): + stop_token_pos = [ + find_first_occurrence_subsequence( + sequences[i][input_len:], torch.tensor(stop_token_id).to(sequences.device) + ) + for stop_token_id in stop_token_ids + ] + stop_index = min([i for i in stop_token_pos if i != -1], default=-1) + stop_token_id = stop_token_ids[stop_token_pos.index(stop_index)] + if stop_index == -1: + # Sequence does not contain stop_token_ids, this should never happen BTW + logger.warning( + "Generated sequence does not contain stop_token_ids. Please check your chat template config" + ) + print(self.tokenizer.decode(sequences[i], skip_special_tokens=True)) + else: + # Keep stop tokens + stop_index = input_len + stop_index + action_mask[i, stop_index + len(stop_token_id) :] = False - r = self.reward_model( - input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device), - attention_mask=attention_mask_rm.to(device=sequences.device), - ) + generation_end_index = (action_mask == True).sum(dim=-1) - 1 + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + action_mask = action_mask[:, -(sequences.size(1) - input_len) :] + num_actions = action_mask.size(1) + torch.cuda.empty_cache() + with torch.inference_mode(): + actor_output = [] + base_model_output = [] + for i in range(0, sequences.size(0), self.logits_forward_batch_size): + actor_output.append( + self.actor( + input_ids=sequences[i : i + self.logits_forward_batch_size], + attention_mask=attention_mask[i : i + self.logits_forward_batch_size], + use_cache=False, + )["logits"] + ) + base_model_output.append( + self.initial_model( + input_ids=sequences[i : i + self.logits_forward_batch_size], + attention_mask=attention_mask[i : i + self.logits_forward_batch_size], + use_cache=False, + )["logits"] + ) + actor_output = torch.cat(actor_output, dim=0) + base_model_output = torch.cat(base_model_output, dim=0) + action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) + base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) - value = self.critic( - input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device), - attention_mask=attention_mask_rm.to(device=sequences.device), - ) - reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) - value = value[:, -num_actions:] * action_mask - advantages = self.calculate_advantage(value, reward, num_actions) + # Convert to right padding for the reward model and the critic model + input_ids_rm = torch.zeros_like(sequences, device=sequences.device) + response_start = [] + response_end = [] + attention_mask_rm = torch.zeros_like(sequences, device=sequences.device) + for i in range(sequences.size(0)): + sequence = sequences[i] + bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0] + eos_index = generation_end_index[i] + 1 # include the stop token + sequence_to_pad = sequence[bos_index:eos_index] + response_start.append(input_len - bos_index) + response_end.append(eos_index - bos_index) + sequence_padded = F.pad( + sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id + ) + input_ids_rm[i] = sequence_padded + if sequence_length - sequence_to_pad.size(0) > 0: + attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1 + else: + attention_mask_rm[i, :] = 1 + attention_mask_rm = attention_mask_rm.to(dtype=torch.bool) - advantages = advantages.detach() - value = value.detach() + r = self.reward_model( + input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device), + attention_mask=attention_mask_rm.to(device=sequences.device), + response_start=response_start, + response_end=response_end, + gt_answer=gt_answer[s:e], + ) + + batch_sequences.append(sequences) + batch_input_ids_rm.append(input_ids_rm) + batch_attention_mask_rm.append(attention_mask_rm) + batch_attention_mask.append(attention_mask) + batch_r.append(r) + batch_action_log_probs.append(action_log_probs.cpu()) + batch_base_action_log_probs.append(base_action_log_probs.cpu()) + batch_action_mask.append(action_mask) + + sequences = torch.cat(batch_sequences, dim=0) + input_ids_rm = torch.cat(batch_input_ids_rm, dim=0) + attention_mask_rm = torch.cat(batch_attention_mask_rm, dim=0) + attention_mask = torch.cat(batch_attention_mask, dim=0) + r = torch.cat(batch_r, dim=0) + action_log_probs = torch.cat(batch_action_log_probs, dim=0).to(sequences.device) + base_action_log_probs = torch.cat(batch_base_action_log_probs, dim=0).to(sequences.device) + action_mask = torch.cat(batch_action_mask, dim=0).to(sequences.device) + if not self.use_grpo: + value = self.critic( + input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device), + attention_mask=attention_mask_rm.to(device=sequences.device), + ) + value = value[:, -num_actions:] * action_mask + reward, kl = compute_reward( + r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask + ) + advantages = self.calculate_advantage(value, reward, num_actions) + advantages = advantages.detach() + value = value.detach() + else: + # GRPO advantage calculation + kl = torch.sum( + -self.kl_coef * (action_log_probs - base_action_log_probs) * action_mask, dim=-1 + ) / torch.sum( + action_mask, dim=-1 + ) # address numerical instability issue + r = kl + r + mean_gr = r.view(-1, self.num_generation).mean(dim=1) + std_gr = r.view(-1, self.num_generation).std(dim=1) + mean_gr = mean_gr.repeat_interleave(self.num_generation, dim=0) + std_gr = std_gr.repeat_interleave(self.num_generation, dim=0) + advantages = (r - mean_gr) / (std_gr + 1e-4) + value = r.detach() # dummy value r = r.detach() - - return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask) + return Experience( + sequences.cpu(), + action_log_probs.cpu(), + value.cpu(), + r.cpu(), + kl.cpu(), + advantages.cpu(), + attention_mask.cpu(), + action_mask.cpu(), + ) diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py index fba0949e3..a804384d1 100755 --- a/applications/ColossalChat/coati/models/__init__.py +++ b/applications/ColossalChat/coati/models/__init__.py @@ -4,12 +4,14 @@ from .generation import generate, generate_streaming, prepare_inputs_fn, update_ from .lora import LoraConfig, convert_to_lora_module, lora_manager from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from .reward_model import RewardModel +from .rlvr_reward_model import RLVRRewardModel from .utils import disable_dropout __all__ = [ "BaseModel", "Critic", "RewardModel", + "RLVRRewardModel", "PolicyLoss", "ValueLoss", "LogSigLoss", diff --git a/applications/ColossalChat/coati/models/generation.py b/applications/ColossalChat/coati/models/generation.py index b671ef124..e5ed07a7e 100755 --- a/applications/ColossalChat/coati/models/generation.py +++ b/applications/ColossalChat/coati/models/generation.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Callable, List, Optional import torch @@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict: return model_kwargs -def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict: +def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict: model_kwargs["input_ids"] = input_ids return model_kwargs def _sample( model: Any, + tokenizer: Any, input_ids: torch.Tensor, max_length: int, early_stopping: bool = True, @@ -137,8 +139,8 @@ def _sample( if max_new_tokens is None: max_new_tokens = max_length - context_length if context_length + max_new_tokens > max_length or max_new_tokens == 0: + print("Exeeded length limitation") return input_ids - logits_processor = _prepare_logits_processor(top_k, top_p, temperature) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) past = None @@ -183,18 +185,14 @@ def _sample( if stop_token_ids is not None: # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished. - tokens_to_check = input_ids[:, -len(stop_token_ids) :] - unfinished_sequences = unfinished_sequences.mul( - torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long() - ) + for stop_token_id in stop_token_ids: + tokens_to_check = input_ids[:, -len(stop_token_id) :] + unfinished_sequences = unfinished_sequences.mul( + torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long() + ) # Stop when each sentence is finished if early_stopping=True if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1: - if i == context_length + max_new_tokens - 1: - # Force to end with stop token ids - input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = ( - torch.LongTensor(stop_token_ids).to(input_ids.device).long() - ) return input_ids @@ -237,8 +235,10 @@ def generate( raise NotImplementedError elif is_sample_gen_mode: # Run sample + generation_kwargs = copy.deepcopy(model_kwargs) res = _sample( model, + tokenizer, input_ids, max_length, early_stopping=early_stopping, @@ -249,8 +249,9 @@ def generate( temperature=temperature, prepare_inputs_fn=prepare_inputs_fn, update_model_kwargs_fn=update_model_kwargs_fn, - **model_kwargs, + **generation_kwargs, ) + del generation_kwargs return res elif is_beam_gen_mode: raise NotImplementedError @@ -350,11 +351,17 @@ def _sample_streaming( unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) if stop_token_ids is not None: - # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished. tokens_to_check = input_ids[:, -len(stop_token_ids) :] - unfinished_sequences = unfinished_sequences.mul( - torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long() - ) + if isinstance(stop_token_ids[0], int): + # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished. + unfinished_sequences = unfinished_sequences.mul( + torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long() + ) + else: + for stop_token_id in stop_token_ids: + unfinished_sequences = unfinished_sequences.mul( + torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long() + ) # Stop when each sentence is finished if early_stopping=True if ( diff --git a/applications/ColossalChat/coati/models/reward_model.py b/applications/ColossalChat/coati/models/reward_model.py index 573b9d889..b2e6601ea 100755 --- a/applications/ColossalChat/coati/models/reward_model.py +++ b/applications/ColossalChat/coati/models/reward_model.py @@ -25,7 +25,9 @@ class RewardModel(BaseModel): self.value_head = nn.Linear(self.last_hidden_state_size, 1) self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1)) - def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: outputs = self.model(input_ids, attention_mask=attention_mask) last_hidden_states = outputs["last_hidden_state"] diff --git a/applications/ColossalChat/coati/models/rlvr_reward_model.py b/applications/ColossalChat/coati/models/rlvr_reward_model.py new file mode 100644 index 000000000..13c463691 --- /dev/null +++ b/applications/ColossalChat/coati/models/rlvr_reward_model.py @@ -0,0 +1,50 @@ +""" +reward model +""" + +from typing import Callable, List, Optional + +import torch + + +class RLVRRewardModel: + """ + RLVRReward model class. Support varifiable reward. + + Args: + reward_fn_list List: list of reward functions + **kwargs: all other kwargs as in reward functions + """ + + def __init__(self, reward_fn_list: List[Callable], **kwargs) -> None: + self.reward_fn_list = reward_fn_list + self.kwargs = kwargs + + def __call__( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + response_start: List = None, + response_end: List = None, + gt_answer: List = None, + ) -> torch.Tensor: + # apply varifiable reward + bs = input_ids.size(0) + rewards = torch.zeros(bs, device=input_ids.device) + for i in range(bs): + for reward_fn in self.reward_fn_list: + rewards[i] += reward_fn( + input_ids[i], + attention_mask[i], + response_start=response_start[i], + response_end=response_end[i], + gt_answer=gt_answer[i], + **self.kwargs, + ) + return rewards + + def to(self, device): + return self + + def eval(self): + return self diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py index fe7ab2098..e7a7e7162 100755 --- a/applications/ColossalChat/coati/models/utils.py +++ b/applications/ColossalChat/coati/models/utils.py @@ -142,3 +142,17 @@ def disable_dropout(model: torch.nn.Module): for module in model.modules(): if isinstance(module, torch.nn.Dropout): module.p = 0.0 + + +def repad_to_left(tensor, tokenizer): + repadded_input_ids = [] + max_non_padded_seq_len = 0 + for i in range(tensor.size(0)): + non_pad_indices = (tensor[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0] + start, end = non_pad_indices.min(), non_pad_indices.max() + repadded_input_ids.append(tensor[i][start : end + 1]) + max_non_padded_seq_len = max(max_non_padded_seq_len, repadded_input_ids[-1].size(0)) + repadded_input_ids = [ + F.pad(t, (max_non_padded_seq_len - t.size(0), 0), value=tokenizer.pad_token_id) for t in repadded_input_ids + ] + return torch.stack(repadded_input_ids) diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py index 6d0900153..c1bad9ed1 100755 --- a/applications/ColossalChat/coati/trainer/__init__.py +++ b/applications/ColossalChat/coati/trainer/__init__.py @@ -1,5 +1,6 @@ from .base import OLTrainer, SLTrainer from .dpo import DPOTrainer +from .grpo import GRPOTrainer from .kto import KTOTrainer from .orpo import ORPOTrainer from .ppo import PPOTrainer @@ -15,4 +16,5 @@ __all__ = [ "DPOTrainer", "ORPOTrainer", "KTOTrainer", + "GRPOTrainer", ] diff --git a/applications/ColossalChat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py index bef4ccc3e..a871798eb 100755 --- a/applications/ColossalChat/coati/trainer/base.py +++ b/applications/ColossalChat/coati/trainer/base.py @@ -96,6 +96,7 @@ class OLTrainer(ABC): self.sample_buffer = sample_buffer self.dataloader_pin_memory = dataloader_pin_memory self.callbacks = callbacks + self.num_train_step = 0 @contextmanager def _fit_ctx(self) -> None: @@ -212,5 +213,6 @@ class OLTrainer(ABC): self._update_phase(update_step) # NOTE: this is for on-policy algorithms self.data_buffer.clear() - if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0: - self._save_checkpoint(episode + 1) + + if self.num_train_step > 0 and (self.num_train_step + 1) % (self.save_interval) == 0: + self._save_checkpoint(self.num_train_step + 1) diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index 499113e96..cde13d41e 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -343,7 +343,7 @@ class DPOTrainer(SLTrainer): self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) - if (i + 1) % self.accumulation_steps == 0: + if (self.num_train_step + 1) % self.accumulation_steps == 0: self.optimizer.step() self.optimizer.zero_grad() self.actor_scheduler.step() @@ -358,26 +358,27 @@ class DPOTrainer(SLTrainer): ) step_bar.update() if self.writer and is_rank_0(): - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + global_step = (self.num_train_step + 1) / self.accumulation_steps + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step) self.writer.add_scalar( - "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step ) self.writer.add_scalar( "train/rejected_rewards", self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.writer.add_scalar( "train/margin", self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.writer.add_scalar( "train/accuracy", self.accumulative_meter.get("accuracy"), - self.num_train_step, + global_step, ) self.num_train_step += 1 self.accumulative_meter.reset() diff --git a/applications/ColossalChat/coati/trainer/grpo.py b/applications/ColossalChat/coati/trainer/grpo.py new file mode 100755 index 000000000..08710b196 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/grpo.py @@ -0,0 +1,386 @@ +""" +GRPO trainer +""" + +import os +from typing import Dict, List, Optional, Union + +import torch +import wandb +from coati.experience_buffer import NaiveExperienceBuffer +from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.models import RewardModel, RLVRRewardModel +from coati.models.loss import GPTLMLoss, PolicyLoss +from coati.models.utils import calc_action_log_probs +from coati.trainer.callbacks import Callback +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader, DistributedSampler +from tqdm import tqdm +from transformers import PreTrainedModel, PreTrainedTokenizerBase + +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin +from colossalai.cluster import DistCoordinator +from colossalai.utils import get_current_device + +from .base import OLTrainer +from .utils import AnnealingScheduler, CycledDataLoader, is_rank_0, to_device + + +def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict: + """ + Set default keyword arguments for generation based on the actor model. + + Args: + actor (PreTrainedModel): The actor model. + + Returns: + Dict: A dictionary containing the default keyword arguments for generation. + """ + unwrapped_model = actor.unwrap() + new_kwargs = {} + # use huggingface models method directly + if hasattr(unwrapped_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation + if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation + return new_kwargs + + +class GRPOTrainer(OLTrainer): + """ + Trainer for GRPO algorithm. + + Args: + strategy (Booster): the strategy to use for training + actor (Actor): the actor model in ppo algorithm + reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences + initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor + actor_optim (Optimizer): the optimizer to use for actor model + kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitation of buffer + buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + vf_coef (float, defaults to 1.0): the coefficient of value loss + ptx_coef (float, defaults to 0.9): the coefficient of ptx loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + sample_buffer (bool, defaults to False): whether to sample from buffer + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__( + self, + actor_booster: Booster, + actor: PreTrainedModel, + reward_model: Union[RewardModel, RLVRRewardModel], + initial_model: PreTrainedModel, + actor_optim: Optimizer, + actor_lr_scheduler: _LRScheduler, + tokenizer: PreTrainedTokenizerBase, + kl_coef: float = 0.1, + ptx_coef: float = 0.9, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.2, + sample_buffer: bool = False, + dataloader_pin_memory: bool = True, + offload_inference_models: bool = True, + apply_loss_mask: bool = True, + accumulation_steps: int = 1, + save_interval: int = 0, + save_dir: str = None, + use_tp: bool = False, + num_generation: int = 8, + inference_batch_size: int = None, + logits_forward_batch_size: int = None, + temperature_annealing_config: Optional[Dict] = None, + coordinator: DistCoordinator = None, + callbacks: List[Callback] = [], + **generate_kwargs, + ) -> None: + if isinstance(actor_booster, GeminiPlugin): + assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')" + + data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + super().__init__(actor_booster, None, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks) + self.generate_kwargs = _set_default_generate_kwargs(actor) + self.generate_kwargs.update(generate_kwargs) + + self.actor = actor + self.actor_booster = actor_booster + self.actor_scheduler = actor_lr_scheduler + self.tokenizer = tokenizer + self.experience_maker = NaiveExperienceMaker( + self.actor, + None, + reward_model, + initial_model, + self.tokenizer, + kl_coef, + use_grpo=True, + num_generation=num_generation, + inference_batch_size=inference_batch_size, + logits_forward_batch_size=logits_forward_batch_size, + ) + if temperature_annealing_config: + # use annealing + self.temperature_annealing_scheduler = AnnealingScheduler( + temperature_annealing_config["start_temperature"], + temperature_annealing_config["end_temperature"], + temperature_annealing_config["annealing_warmup_steps"], + temperature_annealing_config["annealing_steps"], + ) + else: + self.temperature_annealing_scheduler = None + + self.train_batch_size = train_batch_size + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.vf_coef = vf_coef + self.ptx_loss_fn = GPTLMLoss() + self.ptx_coef = ptx_coef + self.actor_optim = actor_optim + self.save_interval = save_interval + self.apply_loss_mask = apply_loss_mask + self.coordinator = coordinator + self.actor_save_dir = os.path.join(save_dir, "actor") + self.num_train_step = 0 + self.accumulation_steps = accumulation_steps + self.use_tp = use_tp + self.accumulative_meter = AccumulativeMeanMeter() + self.offload_inference_models = offload_inference_models + self.device = get_current_device() + + def _before_fit( + self, + prompt_dataloader: DataLoader, + pretrain_dataloader: Optional[DataLoader] = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + """ + self.prompt_dataloader = CycledDataLoader(prompt_dataloader) + self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None + + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + self.wandb_run = wandb.init(project="Coati-grpo", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "grpo") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _setup_update_phrase_dataload(self): + """ + why not use distributed_dataloader? + if tp is used, input on each rank is the same and we use the same dataloader to feed same experience to all ranks + if tp is not used, input on each rank is different and we expect different experiences to be fed to each rank + """ + self.dataloader = DataLoader( + self.data_buffer, + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True, + pin_memory=self.dataloader_pin_memory, + collate_fn=self.data_buffer.collate_fn, + ) + + def _make_experience(self, collect_step: int) -> Experience: + """ + Make experience + """ + prompts = self.prompt_dataloader.next() + if self.offload_inference_models: + # TODO(ver217): this may be controlled by strategy if they are prepared by strategy + self.experience_maker.initial_model.to(self.device) + self.experience_maker.reward_model.to(self.device) + if self.temperature_annealing_scheduler: + self.generate_kwargs["temperature"] = self.temperature_annealing_scheduler.get_temperature() + return self.experience_maker.make_experience( + input_ids=prompts["input_ids"].to(get_current_device()), + attention_mask=prompts["attention_mask"].to(get_current_device()), + gt_answer=prompts["gt_answer"], + **self.generate_kwargs, + ) + + def _training_step(self, experience: Experience): + """ + Args: + experience: + sequences: [batch_size, prompt_length + response_length] --- ............ + """ + self.num_train_step += 1 + self.actor.train() + num_actions = experience.action_log_probs.size(1) + # policy loss + + actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[ + "logits" + ] # [batch size, prompt_length + response_length] + action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) + actor_loss, to_skip, max_ratio = self.actor_loss_fn( + action_log_probs, + experience.action_log_probs, + experience.advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), + action_mask=experience.action_mask if self.apply_loss_mask else None, + ) + # sequence that is not end properly are not counted in token cost + token_cost = torch.sum( + (experience.sequences[:, -num_actions:] != self.tokenizer.pad_token_id).to(torch.float), axis=-1 + ).to(actor_logits.device) + end_properly = experience.sequences[:, -1] == self.tokenizer.pad_token_id + mean_token_cost = torch.sum(token_cost * end_properly) / torch.sum(end_properly) + actor_loss = (1 - self.ptx_coef) * actor_loss + if not to_skip: + self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim) + + # ptx loss + if self.ptx_coef != 0: + batch = self.pretrain_dataloader.next() + batch = to_device(batch, self.device) + outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + ptx_loss = outputs.loss + ptx_loss = self.ptx_coef * ptx_loss + self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim) + + # sync + actor_loss_mean = all_reduce_mean(tensor=actor_loss) + max_ratio_mean = all_reduce_mean(tensor=max_ratio) + reward_mean = all_reduce_mean(tensor=experience.reward.mean()) + advantages_mean = all_reduce_mean(tensor=experience.advantages.mean()) + kl_mean = all_reduce_mean(tensor=experience.kl.mean()) + mean_token_cost = all_reduce_mean(tensor=mean_token_cost) + if self.ptx_coef != 0: + ptx_loss_mean = all_reduce_mean(tensor=ptx_loss) + + self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item()) + self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item()) + self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0) + self.accumulative_meter.add("mean_token_cost", mean_token_cost.to(torch.float16).item()) + self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item()) + if self.ptx_coef != 0: + self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item()) + + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: + self.actor_optim.step() + self.actor_optim.zero_grad() + self.actor_scheduler.step() + + if self.temperature_annealing_scheduler: + self.temperature_annealing_scheduler.step_forward() + + # preparing logging model output and corresponding rewards. + if self.num_train_step % 10 == 1: + response_text = self.experience_maker.tokenizer.batch_decode( + experience.sequences, skip_special_tokens=True + ) + for i in range(len(response_text)): + response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}" + + if self.writer and is_rank_0() and "wandb_run" in self.__dict__: + # log output to wandb + my_table = wandb.Table( + columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text] + ) + try: + self.wandb_run.log({"sample_response": my_table}) + except OSError as e: + self.coordinator.print_on_master(e) + elif self.writer and is_rank_0(): + for line in response_text: + self.coordinator.print_on_master(line) + + if self.writer and is_rank_0(): + global_step = (self.num_train_step + 1) / self.accumulation_steps + self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), global_step) + self.writer.add_scalar("train/skip_ratio", self.accumulative_meter.get("skip_ratio"), global_step) + self.writer.add_scalar("train/actor_loss", self.accumulative_meter.get("actor_loss"), global_step) + self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], global_step) + if self.ptx_coef != 0: + self.writer.add_scalar("train/ptx_loss", self.accumulative_meter.get("ptx_loss"), global_step) + self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), global_step) + self.writer.add_scalar("token_cost", self.accumulative_meter.get("mean_token_cost"), global_step) + self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step) + self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step) + self.accumulative_meter.reset() + + def _learn(self, update_step: int): + """ + Perform the learning step of the PPO algorithm. + + Args: + update_step (int): The current update step. + + Returns: + None + """ + if self.offload_inference_models: + self.experience_maker.initial_model.to("cpu") + self.experience_maker.reward_model.to("cpu") + # buffer may be empty at first, we should rebuild at each training + if self.sample_buffer: + experience = self.data_buffer.sample() + self._on_learn_batch_start() + experience.to_device(self.device) + self._training_step(experience) + self._on_learn_batch_end(experience) + else: + if isinstance(self.dataloader.sampler, DistributedSampler): + self.dataloader.sampler.set_epoch(update_step) + pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0()) + for experience in pbar: + self._on_learn_batch_start() + experience.to_device(self.device) + self._training_step(experience) + self._on_learn_batch_end(experience) + + def _save_checkpoint(self, num_train_step: int = 0): + """ + Save the actor checkpoints with running states. + + Args: + num_train_step (int): The current num_train_step number. + + Returns: + None + """ + + self.coordinator.print_on_master("\nStart saving actor checkpoint with running states") + save_checkpoint( + save_dir=self.actor_save_dir, + booster=self.actor_booster, + model=self.actor, + optimizer=self.actor_optim, + lr_scheduler=self.actor_scheduler, + epoch=0, + step=num_train_step + 1, + batch_size=self.train_batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved actor checkpoint at episode {(num_train_step + 1)} at folder {self.actor_save_dir}" + ) diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index f0b23afb6..2d7e2fa85 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -217,25 +217,25 @@ class KTOTrainer(SLTrainer): self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item()) - if i % self.accumulation_steps == self.accumulation_steps - 1: - self.num_train_step += 1 + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: step_bar.update() # logging if self.writer and is_rank_0(): - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + global_step = (self.num_train_step + 1) / self.accumulation_steps + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step) self.writer.add_scalar( - "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step ) self.writer.add_scalar( "train/rejected_rewards", self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.writer.add_scalar( "train/margin", self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.accumulative_meter.reset() @@ -256,6 +256,7 @@ class KTOTrainer(SLTrainer): self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" ) + self.num_train_step += 1 step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index 761fd305a..0224c8f34 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ b/applications/ColossalChat/coati/trainer/orpo.py @@ -184,35 +184,35 @@ class ORPOTrainer(SLTrainer): self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item()) self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) - if i % self.accumulation_steps == self.accumulation_steps - 1: - self.num_train_step += 1 + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: step_bar.update() + global_step = (self.num_train_step + 1) / self.accumulation_steps # logging if self.writer and is_rank_0(): - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step) self.writer.add_scalar( - "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step ) self.writer.add_scalar( "train/rejected_rewards", self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.writer.add_scalar( "train/margin", self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.writer.add_scalar( "train/accuracy", self.accumulative_meter.get("accuracy"), - self.num_train_step, + global_step, ) self.writer.add_scalar( "train/log_odds_ratio", self.accumulative_meter.get("log_odds_ratio"), - self.num_train_step, + global_step, ) self.accumulative_meter.reset() @@ -233,6 +233,7 @@ class ORPOTrainer(SLTrainer): self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" ) + self.num_train_step += 1 step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py index 63c813b39..331425174 100755 --- a/applications/ColossalChat/coati/trainer/ppo.py +++ b/applications/ColossalChat/coati/trainer/ppo.py @@ -3,13 +3,13 @@ PPO trainer """ import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import torch import wandb from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience, NaiveExperienceMaker -from coati.models import Critic, RewardModel +from coati.models import Critic, RewardModel, RLVRRewardModel from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss from coati.models.utils import calc_action_log_probs from coati.trainer.callbacks import Callback @@ -84,7 +84,7 @@ class PPOTrainer(OLTrainer): critic_booster: Booster, actor: PreTrainedModel, critic: Critic, - reward_model: RewardModel, + reward_model: Union[RewardModel, RLVRRewardModel], initial_model: PreTrainedModel, actor_optim: Optimizer, critic_optim: Optimizer, @@ -210,6 +210,7 @@ class PPOTrainer(OLTrainer): return self.experience_maker.make_experience( input_ids=prompts["input_ids"].to(get_current_device()), attention_mask=prompts["attention_mask"].to(get_current_device()), + gt_answer=prompts["gt_answer"], **self.generate_kwargs, ) diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index 82e4625b9..991167a91 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -150,29 +150,29 @@ class RewardModelTrainer(SLTrainer): self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item()) - if (i + 1) % self.accumulation_steps == 0: + if (self.num_train_step + 1) % self.accumulation_steps == 0: self.optimizer.step() self.optimizer.zero_grad() self.actor_scheduler.step() step_bar.update() - self.num_train_step += 1 # Logging if self.writer and is_rank_0(): - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + global_step = (self.num_train_step + 1) / self.accumulation_steps + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step) self.writer.add_scalar( "train/dist", self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), - self.num_train_step, + global_step, ) self.writer.add_scalar( - "train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + "train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), global_step ) self.writer.add_scalar( - "train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step + "train/reward_reject", self.accumulative_meter.get("rejected_rewards"), global_step ) - self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step) + self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), global_step) self.accumulative_meter.reset() @@ -193,6 +193,7 @@ class RewardModelTrainer(SLTrainer): self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}" ) + self.num_train_step += 1 step_bar.close() def _eval(self, epoch): diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index 3aedcf7a9..fe7f4978b 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -143,15 +143,15 @@ class SFTTrainer(SLTrainer): self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) # Gradient accumulation - if (i + 1) % self.accumulation_steps == 0: + if (self.num_train_step + 1) % self.accumulation_steps == 0: self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() - + global_step = (self.num_train_step + 1) / self.accumulation_steps step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) if self.writer: - self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) - self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step) self.num_train_step += 1 self.accumulative_meter.reset() step_bar.update() diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 217a87cf0..22a5f492e 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -12,6 +12,27 @@ from torch.utils.data import DataLoader from colossalai.booster import Plugin +class AnnealingScheduler: + def __init__(self, start, end, warmup_steps=100, annealing_step=2000): + self.start = start + self.end = end + self.warmup_steps = warmup_steps + self.step = 0 + self.annealing_step = annealing_step + + def get_temperature(self): + if self.step <= self.warmup_steps: + return self.start # Stop annealing after warm-up steps + elif self.step >= self.annealing_step: + return self.end + # Linear annealing + temp = self.start - (self.step / self.annealing_step) * (self.start - self.end) + return temp + + def step_forward(self): + self.step += 1 + + class CycledDataLoader: """ A data loader that cycles through the data when it reaches the end. diff --git a/applications/ColossalChat/coati/utils/reward_score/__init__.py b/applications/ColossalChat/coati/utils/reward_score/__init__.py new file mode 100644 index 000000000..2bc90b9eb --- /dev/null +++ b/applications/ColossalChat/coati/utils/reward_score/__init__.py @@ -0,0 +1,4 @@ +from .competition import math_competition_reward_fn +from .gsm8k import gsm8k_reward_fn + +__all__ = ["gsm8k_reward_fn", "math_competition_reward_fn"] diff --git a/applications/ColossalChat/coati/utils/reward_score/competition.py b/applications/ColossalChat/coati/utils/reward_score/competition.py new file mode 100644 index 000000000..60c869e14 --- /dev/null +++ b/applications/ColossalChat/coati/utils/reward_score/competition.py @@ -0,0 +1,26 @@ +import torch + +from .utils import extract_solution, validate_response_structure + + +def math_competition_reward_fn(input_ids, attention_mask, **kwargs): + # apply varifiable reward + # reward 10 points if the final answer is correct, reward 1 point if format is correct + + gt_answer = kwargs["gt_answer"] + tokenizer = kwargs["tokenizer"] + s, e = kwargs["response_start"], kwargs["response_end"] + reward = torch.tensor(0.0).to(input_ids.device) + if gt_answer is None: + return reward + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) + + format_valid = validate_response_structure(processed_str, kwargs["tags"]) + if not format_valid: + return reward + else: + reward += 1.0 + if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): + reward = reward + 9.0 + return reward diff --git a/applications/ColossalChat/coati/utils/reward_score/gsm8k.py b/applications/ColossalChat/coati/utils/reward_score/gsm8k.py new file mode 100644 index 000000000..b82202c44 --- /dev/null +++ b/applications/ColossalChat/coati/utils/reward_score/gsm8k.py @@ -0,0 +1,31 @@ +import torch + +from .utils import extract_solution, validate_response_structure + + +def gsm8k_reward_fn(input_ids, attention_mask, **kwargs): + # apply varifiable reward + # reward 10 points if the final answer is correct, reward 1 point if format is correct + + gt_answer = kwargs["gt_answer"] + tokenizer = kwargs["tokenizer"] + s, e = kwargs["response_start"], kwargs["response_end"] + reward = torch.tensor(0.0).to(input_ids.device) + if gt_answer is None: + return reward + decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) + is_valid = True + try: + int(final_answer.strip()) + except Exception: + is_valid = False + + format_valid = validate_response_structure(processed_str, kwargs["tags"]) + if not is_valid or not format_valid: + return reward + else: + reward += 1.0 + if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): + reward = reward + 9.0 + return reward diff --git a/applications/ColossalChat/coati/utils/reward_score/utils.py b/applications/ColossalChat/coati/utils/reward_score/utils.py new file mode 100644 index 000000000..c1e73d4b9 --- /dev/null +++ b/applications/ColossalChat/coati/utils/reward_score/utils.py @@ -0,0 +1,76 @@ +# Copyright Unakar +# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, Optional, Tuple + + +def validate_response_structure(processed_str: str, tags: Dict = None) -> bool: + """Performs comprehensive validation of response structure. + + Args: + processed_str: Processed response string from the model + + Returns: + Boolean indicating whether all formatting requirements are met + """ + validation_passed = True + # Check required tags + if tags is None: + tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + positions = {} + for tag_name, tag_info in tags.items(): + tag_str = tag_info["text"] + expected_count = tag_info["num_occur"] + count = processed_str.count(tag_str) + positions[tag_name] = pos = processed_str.find(tag_str) + if count != expected_count: + validation_passed = False + # Verify tag order + if ( + positions["think_start"] > positions["think_end"] + or positions["think_end"] > positions["answer_start"] + or positions["answer_start"] > positions["answer_end"] + ): + validation_passed = False + if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]): + validation_passed = False + return validation_passed + + +def extract_solution(solution_str: str) -> Tuple[Optional[str], str]: + """Extracts the final answer from the model's response string. + + Args: + solution_str: Raw response string from the language model + + Returns: + Tuple containing (extracted_answer, processed_string) + """ + + # Extract final answer using XML-style tags + answer_pattern = r"(.*?)" + matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL)) + + if not matches: + return None, solution_str + + final_answer = matches[-1].group(1).strip() + return final_answer, solution_str diff --git a/applications/ColossalChat/conversation_template/MiniCPM-2b.json b/applications/ColossalChat/conversation_template/MiniCPM-2b.json new file mode 100644 index 000000000..2fb1c870c --- /dev/null +++ b/applications/ColossalChat/conversation_template/MiniCPM-2b.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 122753 + ], + "end_of_assistant": "<|im_end|>" +} diff --git a/applications/ColossalChat/conversation_template/Qwen_Qwen2.5-3B.json b/applications/ColossalChat/conversation_template/Qwen_Qwen2.5-3B.json new file mode 100644 index 000000000..9f9c9020f --- /dev/null +++ b/applications/ColossalChat/conversation_template/Qwen_Qwen2.5-3B.json @@ -0,0 +1,26 @@ +{ + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., 123 .\n", + "stop_ids": [ + 151643 + ], + "end_of_assistant": "<|endoftext|>", + "response_format_tags": { + "think_start": { + "text": "", + "num_occur": 1 + }, + "think_end": { + "text": "", + "num_occur": 1 + }, + "answer_start": { + "text": "", + "num_occur": 1 + }, + "answer_end": { + "text": "", + "num_occur": 1 + } + } +} diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index fec7bc061..fb6438a09 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -27,6 +27,7 @@ - [Reward](#reward) - [KL Divergence](#approximate-kl-divergence) - [Note on PPO Training](#note-on-ppo-training) + - [GRPO Training and DeepSeek R1 reproduction] - [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization) - [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning) - [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training) @@ -725,6 +726,75 @@ Answer: The causes of this problem are two-fold. Check your reward model, make s #### Q4: Generation is garbage Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviate from its original state, which may leads to decrease in language modeling capabilities. A way to fix this is to add supervised loss during PPO. Set ptx_coef to an non-zero value (between 0 and 1), which balances PPO loss and sft loss. +## GRPO Training and DeepSeek R1 reproduction +We support GRPO (Group Relative Policy Optimization), which is the reinforcement learning algorithm used in DeepSeek R1 paper. In this section, we will walk through GRPO training with an example trying to reproduce Deepseek R1's results in mathematical problem solving. + +### GRPO Model Selection +We finally select the base version of [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B). We also did experiments on the instruct version [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) but the later one fails to explore more diversed output. We recommend to use base models (without SFT) and use a few SFT steps (see [SFT section](#rlhf-training-stage1---supervised-instructs-tuning)) to correct the base model's output format before GRPO. + +### Reinforcement Learning with Verifiable Reward +Both the PPO and the GRPO support reinforcement learning with verifiable reward (RLVR). In this experiment on mathematical problem solving, we define the reward function as following, in the following definition, forward is correct if there are exactly one pair of , tags in the response and the order of the tags is correct. + +- reward=0, if format is incorrect. +- reward=1, if format is correct but the answer doesn't match the ground truth answer exactly. +- reward=10, if format is correct and the answer match the ground truth answer exactly. + +### Step 1: Data Collection & Preparation +For GPRO training, you only need the prompt dataset. Please follow the instruction in the [prompt dataset preparation](#rlhf-training-stage3---proximal-policy-optimization) to prepare the prompt data for GPRO training. In our reproduction experiment, we use the [qwedsacf/competition_math dataset](https://huggingface.co/datasets/qwedsacf/competition_math), which is available on Huggingface. + +### Step 2: Training +You can run the [train_grpo.sh](./training_scripts/train_grpo.sh) to start GRPO training. The script share most of its arguments with the PPO script (please refer to the [PPO training section](#step-3-training) for more details). Here are some unique arguments for GRPO. + +```bash +--num_generations 8 \ # number of roll outs to collect for each prompt +--inference_batch_size 8 \ # batch size used during roll out +--logits_forward_batch_size 1 \ # batch size used to calculate logits for GRPO training +--initial_temperature \ # initial temperature for annealing algorithm +--final_temperature \ # final temperature for annealing algorithm +``` + +As the GRPO requires to collect a group of response from each prompt (usually greater than 8), the effective batch size will satisfy the following constraints, + +- Without tensor parallelism, +``` +experience buffer size += num_process * num_collect_steps * experience_batch_size * num_generations += train_batch_size * accumulation_steps * num_process +``` + +- With tensor parallelism, +``` +num_tp_group = num_process / tp +experience buffer size += num_tp_group * num_collect_steps * experience_batch_size * num_generations += train_batch_size * accumulation_steps * num_tp_group +``` + +During roll out, we perform rebatching to prevent out of memory both before roll out and before calculating logits. Please choose a proper setting for the "inference_batch_size" and the "logits_forward_batch_size" based on your device. + +### GRPO Result +#### Reward +

+image +

+ +#### Response Length +

+image +

+ +#### Response Length Distribution (After Training) +

+image +

+ +#### Sample Response +

+image +

+ +#### Note of Speed +Currently, our PPO and GRPO pipeline are still under development. The speed is largely limited by the roll out speed as we use naive generation without any acceleration. ## Alternative Option For RLHF: Direct Preference Optimization diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh index d74667889..f0f0714dd 100755 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh @@ -11,4 +11,4 @@ python prepare_dataset.py --type prompt \ --data_cache_dir $SAVE_DIR/cache \ --data_jsonl_output_dir $SAVE_DIR/jsonl \ --data_arrow_output_dir $SAVE_DIR/arrow \ - --max_length 1024 + --max_length 300 diff --git a/applications/ColossalChat/examples/requirements.txt b/applications/ColossalChat/examples/requirements.txt index 91f25a5cf..31eef5256 100644 --- a/applications/ColossalChat/examples/requirements.txt +++ b/applications/ColossalChat/examples/requirements.txt @@ -1,4 +1,4 @@ pandas>=1.4.1 sentencepiece -colossalai==0.4.0 +colossalai==0.4.7 prompt_toolkit diff --git a/applications/ColossalChat/examples/training_scripts/train_grpo.py b/applications/ColossalChat/examples/training_scripts/train_grpo.py new file mode 100755 index 000000000..6acdbebb1 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_grpo.py @@ -0,0 +1,494 @@ +import argparse +import json +import os +import resource +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from coati.dataset import ( + DataCollatorForPromptDataset, + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_conversation_template, +) +from coati.models import LoraConfig, RewardModel, RLVRRewardModel, convert_to_lora_module, disable_dropout, lora_manager +from coati.trainer import GRPOTrainer +from coati.utils import load_checkpoint +from coati.utils.reward_score import * +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer.policies.auto_policy import get_autopolicy + +logger = get_dist_logger() +# default settings for response format tags, overwrite it in chat_template definition if needed +response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, +} + + +def train(args): + global response_format_tags + lora_config = None + if args.lora_config is not None: + lora_config = LoraConfig.from_file(args.lora_config) + # check lora compatibility + if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + + init_ctx = nullcontext() + with init_ctx: + if args.use_flash_attn: + actor = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + local_files_only=True, + trust_remote_code=True, + ) + ref_model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + local_files_only=True, + trust_remote_code=True, + ) + if args.rm_pretrain: + reward_model = RewardModel( + args.rm_pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + trust_remote_code=True, + ) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True) + if args.rm_pretrain: + reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True) + ref_model = AutoModelForCausalLM.from_pretrained( + args.pretrain, local_files_only=True, trust_remote_code=True + ) + + if args.lora_config is not None: + actor = convert_to_lora_module(actor, lora_config=lora_config) + for name, module in actor.named_modules(): + if "norm" in name or "gate" in name: + module = module.to(torch.float32) + lora_manager.able_to_merge = False + + # Disable dropout + disable_dropout(actor) + + if args.grad_checkpoint: + actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + + # configure tokenizer + tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True) + if os.path.exists(args.conversation_template_config): + with open(args.conversation_template_config, "r", encoding="utf8") as f: + conversation_template_config = json.load(f) + dist.barrier() + if "response_format_tags" in conversation_template_config: + logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}") + response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags) + conversation_template = setup_conversation_template( + tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config + ) + stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None + else: + raise ValueError("Conversation template config is not provided or incorrect") + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." + ) + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + tokenizer.padding_side = "left" # left padding for generation (online learning) + + # configure generation config + actor.generation_config.update( + pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id + ) + + # configure optimizer + coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}") + actor_optim = HybridAdam( + model_params=actor.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + if args.warmup_steps is None: + args.warmup_steps = int(0.025 * args.num_episodes) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + actor_lr_scheduler = CosineAnnealingWarmupLR( + optimizer=actor_optim, + total_steps=args.num_episodes, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for + debugging purpose acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism): + logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.") + args.use_flash_attn = False + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + parallel_output=False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + ) + if args.rm_pretrain: + custom_plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + parallel_output=False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + custom_policy=get_autopolicy(reward_model.model), + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + if args.plugin != "3d" and args.rm_pretrain: + custom_plugin = plugin + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map) + + data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len) + + train_prompt_dataloader = plugin.prepare_dataloader( + dataset=train_prompt_dataset, + batch_size=args.experience_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + + if len(args.ptx_dataset) > 0: + train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) + train_pretrain_dataloader = plugin.prepare_dataloader( + dataset=train_ptx_dataset, + batch_size=args.ptx_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, + ) + else: + train_pretrain_dataloader = None + + actor_booster = Booster(plugin=plugin) + ref_booster = Booster(plugin=plugin) + if args.rm_pretrain: + rm_booster = Booster(plugin=custom_plugin) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost( + model=actor, + optimizer=actor_optim, + lr_scheduler=actor_lr_scheduler, + dataloader=train_prompt_dataloader, + ) + if args.rm_pretrain: + reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader) + else: + if args.reward_functions: + reward_fn_list = [] + for reward_fn in args.reward_functions: + """ + To define custom reward function, you can define your functions under: + colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py + and use it here by mofiying the following line: + """ + if reward_fn == "gsm8k_reward_fn": + reward_fn_list.append(gsm8k_reward_fn) + elif reward_fn == "math_competition_reward_fn": + reward_fn_list.append(math_competition_reward_fn) + else: + raise ValueError(f"Unknown reward function {reward_fn}") + reward_model = RLVRRewardModel( + reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags + ) + + ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader) + + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + sampler_start_idx = 0 + start_step = 0 + + if args.rm_checkpoint_path is not None: + if "modeling" in args.rm_checkpoint_path: + rm_booster.load_model(reward_model, args.rm_checkpoint_path) + else: + _, _, _ = load_checkpoint( + load_dir=args.rm_checkpoint_path, + booster=rm_booster, + model=reward_model, + optimizer=None, + lr_scheduler=None, + ) + coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}") + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + actor_booster.load_model(actor, args.checkpoint_path) + ref_booster.load_model(ref_model, args.checkpoint_path) + coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}") + else: + _, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=actor_booster, + model=actor, + optimizer=actor_optim, + lr_scheduler=actor_lr_scheduler, + ) + _, _, _ = load_checkpoint(load_dir=args.checkpoint_path, booster=ref_booster, model=ref_model) + assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler) + train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + # configure trainer + trainer = GRPOTrainer( + actor_booster, + actor, + reward_model, + ref_model, + actor_optim, + actor_lr_scheduler, + tokenizer=tokenizer, + stop_token_ids=[stop_ids], + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + train_batch_size=args.train_batch_size, + buffer_limit=args.num_collect_steps * args.experience_batch_size * args.num_generations, + max_length=args.max_length, + use_cache=True, + do_sample=True, + apply_loss_mask=not args.disable_loss_mask, + accumulation_steps=args.accumulation_steps, + save_dir=args.save_path, + save_interval=args.save_interval, + top_k=50, + use_tp=args.tp > 1, + num_generations=args.num_generations, + inference_batch_size=args.inference_batch_size, + logits_forward_batch_size=args.logits_forward_batch_size, + offload_inference_models="gemini" not in args.plugin, + coordinator=coordinator, + max_tokens_thinking=args.max_tokens_thinking if args.max_tokens_thinking else args.max_length - 100, + temperature_annealing_config={ + "start_temperature": args.initial_temperature, + "end_temperature": args.final_temperature, + "annealing_warmup_steps": min(100, int(args.num_episodes / 6)), + "annealing_steps": min(600, int(args.num_episodes / 2)), + }, + # Hack: some old model's default update_model_kwargs_fn/prepare_inputs_fn may doesn't work due to version conflict with transformers, you can overwrite them + # update_model_kwargs_fn=update_model_kwargs_fn, + # prepare_inputs_fn = None + ) + + trainer.fit( + num_episodes=args.num_episodes, + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps, + prompt_dataloader=train_prompt_dataloader, + pretrain_dataloader=train_pretrain_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if lora_config is not None and lora_config.r > 0: + # NOTE: set model to eval to merge LoRA weights + lora_manager.able_to_merge = True + actor.eval() + # save model checkpoint after fitting on only rank0 + coordinator.print_on_master("Start saving final actor model checkpoint") + actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}" + ) + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt_dataset", nargs="+", default=[]) + parser.add_argument("--ptx_dataset", nargs="+", default=[]) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument( + "--conversation_template_config", + type=str, + default=None, + help="Path \ + to save conversation template config files.", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true") + parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) + parser.add_argument("--zero_cpu_offload", default=False, action="store_true") + parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path") + parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use") + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--num_collect_steps", type=int, default=2) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--num_generations", type=int, default=8) + parser.add_argument("--inference_batch_size", type=int, default=None) + parser.add_argument("--save_interval", type=int, default=1000) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--logits_forward_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=16) + parser.add_argument("--ptx_batch_size", type=int, default=4) + parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--lr", type=float, default=1e-6) + parser.add_argument("--kl_coef", type=float, default=0.7) + parser.add_argument("--ptx_coef", type=float, default=0.0) + parser.add_argument("--disable_loss_mask", default=False, action="store_true") + parser.add_argument("--max_length", type=int, default=2048) + parser.add_argument("--max_tokens_thinking", type=int, default=2000) + parser.add_argument("--max_seq_len", type=int, default=256) + parser.add_argument("--initial_temperature", type=float, default=1.0) + parser.add_argument("--final_temperature", type=float, default=0.9) + parser.add_argument("--log_dir", default=None, type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + + args = parser.parse_args() + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_grpo.sh b/applications/ColossalChat/examples/training_scripts/train_grpo.sh new file mode 100755 index 000000000..28c89ffbd --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_grpo.sh @@ -0,0 +1,86 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +PROJECT_NAME="PPO-RLVR" + +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT) +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file +LOGDIR="" + +declare -a prompt_dataset=( + YOUR/PROMPT/DATA/DIR/arrow/part-00000 + YOUR/PROMPT/DATA/DIR/arrow/part-00001 + YOUR/PROMPT/DATA/DIR/arrow/part-00002 + YOUR/PROMPT/DATA/DIR/arrow/part-00003 + YOUR/PROMPT/DATA/DIR/arrow/part-00004 + YOUR/PROMPT/DATA/DIR/arrow/part-00005 + YOUR/PROMPT/DATA/DIR/arrow/part-00006 + YOUR/PROMPT/DATA/DIR/arrow/part-00007 + YOUR/PROMPT/DATA/DIR/arrow/part-00008 + YOUR/PROMPT/DATA/DIR/arrow/part-00009 +) + +declare -a ptx_dataset=( + YOUR/SFT/DATA/DIR/arrow/part-00000 + YOUR/SFT/DATA/DIR/arrow/part-00001 + YOUR/SFT/DATA/DIR/arrow/part-00002 + YOUR/SFT/DATA/DIR/arrow/part-00003 + YOUR/SFT/DATA/DIR/arrow/part-00004 + YOUR/SFT/DATA/DIR/arrow/part-00005 + YOUR/SFT/DATA/DIR/arrow/part-00006 + YOUR/SFT/DATA/DIR/arrow/part-00007 + YOUR/SFT/DATA/DIR/arrow/part-00008 + YOUR/SFT/DATA/DIR/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" + +colossalai run --nproc_per_node 8 --num_nodes 1 --hostfile ./hostfile train_grpo.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --prompt_dataset ${prompt_dataset[@]} \ + --conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \ + --ptx_coef 0.0 \ + --plugin "zero2_cpu" \ + --reward_functions math_competition_reward_fn \ + --save_interval 250 \ + --save_path $SAVE_DIR \ + --num_episodes 100 \ + --num_collect_steps 8 \ + --num_update_steps 1 \ + --experience_batch_size 1 \ + --train_batch_size 4 \ + --inference_batch_size 8 \ + --logits_forward_batch_size 2 \ + --accumulation_steps 4 \ + --lr 1e-6 \ + --mixed_precision "bf16" \ + --grad_clip 0.1\ + --weight_decay 0.01 \ + --kl_coef 0.01 \ + --warmup_steps 40 \ + --max_length 2000 \ + --max_seq_len 1700 \ + --log_dir $LOGDIR \ + --use_flash_attn \ + --grad_checkpoint diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py index a0a10e239..4c4f31087 100755 --- a/applications/ColossalChat/examples/training_scripts/train_ppo.py +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -13,9 +13,18 @@ from coati.dataset import ( load_tokenized_dataset, setup_conversation_template, ) -from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager +from coati.models import ( + Critic, + LoraConfig, + RewardModel, + RLVRRewardModel, + convert_to_lora_module, + disable_dropout, + lora_manager, +) from coati.trainer import PPOTrainer from coati.utils import load_checkpoint +from coati.utils.reward_score import * from transformers import AutoModelForCausalLM, AutoTokenizer import colossalai @@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy logger = get_dist_logger() +# default settings for response format tags, overwrite it in chat_template definition if needed +response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, +} + def train(args): + global response_format_tags lora_config = None if args.lora_config is not None: lora_config = LoraConfig.from_file(args.lora_config) @@ -61,28 +79,36 @@ def train(args): torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, use_flash_attention_2=True, local_files_only=True, + trust_remote_code=True, ) ref_model = AutoModelForCausalLM.from_pretrained( args.pretrain, torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, use_flash_attention_2=True, local_files_only=True, + trust_remote_code=True, ) - reward_model = RewardModel( - args.rm_pretrain, - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - use_flash_attention_2=True, - ) + if not args.no_neural_reward_model: + reward_model = RewardModel( + args.rm_pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + trust_remote_code=True, + ) critic = Critic( args.rm_pretrain, torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, use_flash_attention_2=True, + trust_remote_code=True, ) coordinator.print_on_master(msg="Flash-attention enabled successfully") else: - actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True) - ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True) - reward_model = RewardModel(args.rm_pretrain) + actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True) + ref_model = AutoModelForCausalLM.from_pretrained( + args.pretrain, local_files_only=True, trust_remote_code=True + ) + if not args.no_neural_reward_model: + reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True) critic = Critic(args.rm_pretrain) if args.lora_config is not None: @@ -112,6 +138,9 @@ def train(args): with open(args.conversation_template_config, "r", encoding="utf8") as f: conversation_template_config = json.load(f) dist.barrier() + if "response_format_tags" in conversation_template_config: + logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}") + response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags) conversation_template = setup_conversation_template( tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config ) @@ -245,7 +274,7 @@ def train(args): parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, - custom_policy=get_autopolicy(reward_model.model), + custom_policy=get_autopolicy(critic.model), ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -284,7 +313,8 @@ def train(args): actor_booster = Booster(plugin=plugin) ref_booster = Booster(plugin=plugin) - rm_booster = Booster(plugin=custom_plugin) + if not args.no_neural_reward_model: + rm_booster = Booster(plugin=custom_plugin) critic_booster = Booster(plugin=custom_plugin) default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 @@ -302,7 +332,28 @@ def train(args): lr_scheduler=critic_lr_scheduler, dataloader=train_prompt_dataloader, ) - reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader) + if not args.no_neural_reward_model: + reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader) + else: + if args.reward_functions: + reward_fn_list = [] + for reward_fn in args.reward_functions: + """ + To define custom reward function, you can define your functions under: + colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py + and use it here by mofiying the following line: + """ + if reward_fn == "gsm8k_reward_fn": + reward_fn_list.append(gsm8k_reward_fn) + elif reward_fn == "math_competition_reward_fn": + reward_fn_list.append(math_competition_reward_fn) + else: + raise ValueError(f"Unknown reward function {reward_fn}") + reward_fn_list.append(eval(reward_fn)) + reward_model = RLVRRewardModel( + reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags + ) + ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader) torch.set_default_dtype(torch.float) @@ -481,9 +532,11 @@ if __name__ == "__main__": parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"]) parser.add_argument("--pretrain", type=str, default=None) parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--no_neural_reward_model", default=False, action="store_true") parser.add_argument("--checkpoint_path", type=str, default=None) parser.add_argument("--critic_checkpoint_path", type=str, default=None) parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path") + parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use") parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") parser.add_argument("--num_episodes", type=int, default=1) parser.add_argument("--num_collect_steps", type=int, default=2) diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt index ac40ae821..a199071c2 100755 --- a/applications/ColossalChat/requirements.txt +++ b/applications/ColossalChat/requirements.txt @@ -2,7 +2,7 @@ transformers==4.39.3 tqdm datasets==2.14.7 loralib -colossalai>=0.4.0 +colossalai>=0.4.7 torch>=2.1.0 langchain tokenizers diff --git a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py index e50b20b6b..dd6095258 100644 --- a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py +++ b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py @@ -20,6 +20,15 @@ prompt_seed = { }, ] } +prompt_rlvr_seed = { + "messages": [ + { + "from": "user", + "content": "What is the degree of the polynomial $(4 +5x^3 +100 +2\pi x^4 + \sqrt{10}x^4 +9)$?", + }, + ], + "gt_answer": "4", +} preference_seed = { "context": [ {"from": "user", "content": "What kind of noises did dinosaurs make?"}, @@ -72,6 +81,8 @@ if __name__ == "__main__": seed = sft_seed elif args.data_type == "prompt": seed = prompt_seed + elif args.data_type == "prompt_rlvr": + seed = prompt_rlvr_seed elif args.data_type == "preference": seed = preference_seed elif args.data_type == "kto": diff --git a/applications/ColossalChat/tests/prepare_test_env.sh b/applications/ColossalChat/tests/prepare_test_env.sh new file mode 100755 index 000000000..4733bb054 --- /dev/null +++ b/applications/ColossalChat/tests/prepare_test_env.sh @@ -0,0 +1,16 @@ +# run under /ColossalAI/applications/ColossalChat +export NCCL_SHM_DISABLE=1 +export MAX_JOBS=1 +export PRETRAINED_MODEL_PATH=./models +export SFT_DATASET=./sft_data +export PROMPT_DATASET=./prompt_data +export PROMPT_RLVR_DATASET=./prompt_data +export PREFERENCE_DATASET=./preference_data +export KTO_DATASET=./kto_data +mkdir models +mkdir sft_data +mkdir prompt_data +mkdir preference_data +mkdir kto_data +# ./tests/test_data_preparation.sh +# ./tests/test_train.sh diff --git a/applications/ColossalChat/tests/test_data_preparation.sh b/applications/ColossalChat/tests/test_data_preparation.sh index 427c3952b..5bc05c4ec 100755 --- a/applications/ColossalChat/tests/test_data_preparation.sh +++ b/applications/ColossalChat/tests/test_data_preparation.sh @@ -24,7 +24,12 @@ if [ -z "$SFT_DATASET" ]; then fi if [ -z "$PROMPT_DATASET" ]; then - echo "Please set \$PROMPT_DATASET to the path to prompts." + echo "Please set \$PROMPT_DATASET to the path to prompts dataset." + exit 1 +fi + +if [ -z "$PROMPT_RLVR_DATASET" ]; then + echo "Please set \$PROMPT_RLVR_DATASET to the path to prompts dataset with gt_answer labels." exit 1 fi @@ -69,6 +74,8 @@ get_data_input_dirs() { echo "$SFT_DATASET" elif [[ $data_type == "prompt" ]]; then echo "$PROMPT_DATASET" + elif [[ $data_type == "prompt_rlvr" ]]; then + echo "$PROMPT_RLVR_DATASET" elif [[ $data_type == "preference" ]]; then echo "$PREFERENCE_DATASET" elif [[ $data_type == "kto" ]]; then @@ -123,6 +130,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \ --data_dir $(get_data_input_dirs prompt) \ --data_type "prompt" +python $TEST_DIR/generate_dummy_datasets_for_testing.py \ + --data_dir $(get_data_input_dirs prompt_rlvr) \ + --data_type "prompt_rlvr" + python $TEST_DIR/generate_dummy_datasets_for_testing.py \ --data_dir $(get_data_input_dirs kto) \ --data_type "kto" @@ -266,6 +277,52 @@ for model in ${MODELS[@]}; do done +echo "[Test]: testing prepare_prompt_dataset.py (with verifiable reward)..." + +# FIXME: This is a hack to skip tests that are not working +SKIPPED_TESTS=( +) + +# test prepare_prompt_dataset +for model in ${MODELS[@]}; do + data_type="prompt_rlvr" + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then + echo "[Test]: Skipped $model-$data_type" + continue + fi + cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache + jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl + arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow + data_input_dirs=$(get_data_input_dirs $data_type) + tokenizer_dir=$(get_tokenizer_dirs $model) + conversation_template=$(get_conversation_template_config $model) + for i in $(seq $NUM_RETRY); do + rm -rf $cache_dir + rm -rf $jsonl_dir + rm -rf $arrow_dir + echo "[Test]: $model-$data_type, attempt $i" + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \ + --type prompt \ + --data_input_dirs $data_input_dirs \ + --conversation_template_config $conversation_template \ + --tokenizer_dir $tokenizer_dir \ + --data_cache_dir $cache_dir \ + --data_jsonl_output_dir $jsonl_dir \ + --data_arrow_output_dir $arrow_dir \ + --max_length 400 \ + --num_samples_per_datafile 100 \ + --num_spliced_dataset_bins 1 + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$data_type" + exit 1 + fi +done + echo "[Test]: testing prepare_kto_dataset.py ..." # FIXME: This is a hack to skip tests that are not working diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh index 2935a6369..636bb2ad7 100755 --- a/applications/ColossalChat/tests/test_train.sh +++ b/applications/ColossalChat/tests/test_train.sh @@ -81,8 +81,242 @@ random_choice() { echo ${arr[$idx]} } +echo "[Test]: testing grpo ..." +SKIPPED_TESTS=( + llama-3d # 3d plugin doesn't support lora + llama-gemini # gemini doesn't support lora +) + +GRAD_CKPTS=('--grad_checkpoint') +REWARD_FLAG=('nn' 'vr') +for reward_type in ${REWARD_FLAG[@]}; do + for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ $plugin == "gemini_auto" ]]; then + echo "[Test]: Skipped $model-$plugin" + continue # gemini_auto plugin doesn't support generation + fi + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + rm_pretrain="--rm_pretrain $pretrain" + reward_fn="" + if [[ $reward_type == "vr" ]]; then + rm_pretrain="" + reward_fn="--reward_functions gsm8k_reward_fn" + fi + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + ebs='1' + conversation_template=$(get_conversation_template_config $model) + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi + if [[ $plugin == "3d" ]]; then + tp='2' + bs='2' + ebs='1' + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + # gemini_auto and gemini doesn't support generation + if [[ $plugin == "gemini_auto" ]]; then + # gemini-auto doesn't support generation + echo "[Test]: Skipped $model-$plugin" + continue + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i" + declare -a prompt_dataset=() + for split in $(seq -f "%05g" 0 0); do + if [[ $reward_type == "vr" ]]; then + prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split") + else + prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split") + fi + done + declare -a ptx_dataset=() + for split in $(seq -f "%05g" 0 0); do + ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") + done + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \ + --pretrain $pretrain \ + $rm_pretrain \ + --tokenizer_dir $tokenizer_dir \ + --conversation_template_config $conversation_template \ + --prompt_dataset ${prompt_dataset[@]} \ + --ptx_dataset ${ptx_dataset[@]} \ + --ptx_batch_size 1 \ + --num_generations 2 \ + --ptx_coef 0.2 \ + --save_path $MODEL_SAVE_PATH \ + $lora_config \ + --plugin $plugin \ + --num_episodes 5 \ + --num_collect_steps 1 \ + --num_update_steps 1 \ + --experience_batch_size $ebs \ + --train_batch_size $bs \ + --accumulation_steps $grad_accu \ + --lr 9e-6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 200 \ \ + --max_seq_len 10 \ + $reward_fn + # --use_flash_attn + passed=$? + if [ $passed -eq 0 ]; then + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type" + exit 1 + fi + done + done + done +done + + +echo "[Test]: testing ppo ..." + + +SKIPPED_TESTS=( + llama-3d # 3d plugin doesn't support lora + llama-gemini # gemini doesn't support lora +) + +GRAD_CKPTS=('--grad_checkpoint') +REWARD_FLAG=('vr' 'nn') +for reward_type in ${REWARD_FLAG[@]}; do + for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ $plugin == "gemini_auto" ]]; then + echo "[Test]: Skipped $model-$plugin" + continue # gemini_auto plugin doesn't support generation + fi + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + reward_fn="" + no_nn="" + if [[ $reward_type == "vr" ]]; then + reward_fn="--reward_functions gsm8k_reward_fn" + no_nn="--no_neural_reward_model" + fi + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + ebs='2' + conversation_template=$(get_conversation_template_config $model) + if [[ $plugin == "zero2" ]]; then + lora_config=$LORA_CONFIG_ENABLE + else + lora_config="" + fi + if [[ $plugin == "3d" ]]; then + tp='2' + bs='2' + ebs='2' + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + # gemini_auto and gemini doesn't support generation + if [[ $plugin == "gemini_auto" ]]; then + # gemini-auto doesn't support generation + echo "[Test]: Skipped $model-$plugin" + continue + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i" + declare -a prompt_dataset=() + for split in $(seq -f "%05g" 0 0); do + if [[ $reward_type == "vr" ]]; then + prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split") + else + prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split") + fi + done + declare -a ptx_dataset=() + for split in $(seq -f "%05g" 0 0); do + ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") + done + colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \ + --pretrain $pretrain \ + --rm_pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --conversation_template_config $conversation_template \ + --prompt_dataset ${prompt_dataset[@]} \ + --ptx_dataset ${ptx_dataset[@]} \ + --ptx_batch_size 1 \ + --ptx_coef 0.2 \ + --save_path $MODEL_SAVE_PATH \ + $lora_config \ + --plugin $plugin \ + --num_episodes 5 \ + --num_collect_steps 1 \ + --num_update_steps 1 \ + --experience_batch_size $ebs \ + --train_batch_size $bs \ + --accumulation_steps $grad_accu \ + --lr 9e-6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --max_seq_len 10 \ + $reward_fn \ + $no_nn + # --use_flash_attn + passed=$? + if [ $passed -eq 0 ]; then + rm -rf ${MODEL_SAVE_PATH:?}/* + rm -rf ${MODELS_DIR:?}/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type" + exit 1 + fi + done + done + done +done echo "[Test]: testing sft ..." @@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do done done - -echo "[Test]: testing ppo ..." - - -SKIPPED_TESTS=( - llama-3d # 3d plugin doesn't support lora - llama-gemini # gemini doesn't support lora -) - -GRAD_CKPTS=('--grad_checkpoint') -for lora_rank in ${LORA_RANK[@]}; do - for model in ${MODELS[@]}; do - for plugin in ${PLUGINS[@]}; do - if [[ $plugin == "gemini_auto" ]]; then - echo "[Test]: Skipped $model-$plugin" - continue # gemini_auto plugin doesn't support generation - fi - if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then - echo "[Test]: Skipped $model-$plugin-$lora_rank" - continue - elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then - echo "[Test]: Skipped $model-$plugin" - continue - fi - pretrain=$(get_pretrain $model) - tokenizer_dir=$(get_tokenizer_dirs $model) - grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") - tp='1' - bs='4' - ebs='8' - conversation_template=$(get_conversation_template_config $model) - if [[ $plugin == "zero2" ]]; then - lora_config=$LORA_CONFIG_ENABLE - else - lora_config="" - fi - if [[ $plugin == "3d" ]]; then - tp='2' - bs='16' - ebs='32' - fi - grad_accu='2' - # gemini_auto and gemini doesn't support gradient accumulation - if [[ $plugin == "gemini_auto" ]]; then - grad_accu='1' - fi - # gemini_auto and gemini doesn't support generation - if [[ $plugin == "gemini_auto" ]]; then - # gemini-auto doesn't support generation - echo "[Test]: Skipped $model-$plugin" - continue - fi - for i in $(seq $NUM_RETRY); do - echo "[Test]: $model-$plugin-$lora_rank, attempt $i" - declare -a prompt_dataset=() - for split in $(seq -f "%05g" 0 0); do - prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split") - done - declare -a ptx_dataset=() - for split in $(seq -f "%05g" 0 0); do - ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") - done - colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \ - --pretrain $pretrain \ - --rm_pretrain $pretrain \ - --tokenizer_dir $tokenizer_dir \ - --conversation_template_config $conversation_template \ - --prompt_dataset ${prompt_dataset[@]} \ - --ptx_dataset ${ptx_dataset[@]} \ - --ptx_batch_size 1 \ - --ptx_coef 0.2 \ - --save_path $MODEL_SAVE_PATH \ - $lora_config \ - --plugin $plugin \ - --num_episodes 5 \ - --num_collect_steps 1 \ - --num_update_steps 1 \ - --experience_batch_size $ebs \ - --train_batch_size $bs \ - --accumulation_steps $grad_accu \ - --lr 9e-6 \ - --mixed_precision "bf16" \ - --grad_clip 1.0 \ - --tp $tp \ - --lr 2e-5 \ - $grad_ckpt \ - --max_len 400 \ - --max_seq_len 10 \ - # --use_flash_attn - passed=$? - if [ $passed -eq 0 ]; then - rm -rf ${MODEL_SAVE_PATH:?}/* - rm -rf ${MODELS_DIR:?}/* - break - fi - done - if [ $passed -ne 0 ]; then - echo "[Test]: Failed $model-$plugin-$lora_rank" - exit 1 - fi - done - done -done - - echo "[Test]: testing DPO ..." SKIPPED_TESTS=( @@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do bs='2' if [[ $plugin == "3d" ]]; then tp='2' - bs='8' + bs='2' fi if [[ $plugin == "zero2" ]]; then lora_config=$LORA_CONFIG_ENABLE @@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do done - echo "[Test]: testing ORPO ..." SKIPPED_TESTS=( + llama-3d-0 llama-3d-20 # 3d plugin doesn't support lora llama-gemini_auto-20 # gemini_auto plugin doesn't support lora llama-gemini-20 # gemini doesn't support lora @@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do bs='2' if [[ $plugin == "3d" ]]; then tp='2' - bs='8' + bs='2' fi if [[ $plugin == "zero2" ]]; then lora_config=$LORA_CONFIG_ENABLE @@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do done done - - echo "[Test]: testing KTO ..." SKIPPED_TESTS=( + llama-3d-0 llama-3d-20 # 3d plugin doesn't support lora llama-gemini_auto-20 # gemini_auto plugin doesn't support lora llama-gemini-20 # gemini doesn't support lora @@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do bs='2' if [[ $plugin == "3d" ]]; then tp='2' - bs='8' + bs='2' fi if [[ $plugin == "zero2" ]]; then lora_config=$LORA_CONFIG_ENABLE