diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 020432b9e..533d0acad 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-tokenization utils for constructing dataset for ppo, dpo, sft, rm
+Tokenization Utils for Constructing Dataset for RL.
 """
 
 import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Union
 
+import torch
 from coati.dataset.conversation import Conversation
 from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
 from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
     }
+
+
+def tokenize_process_reward(
+    data_point: Dict[str, str],
+    tokenizer: PreTrainedTokenizer,
+    conversation_template: Conversation = None,
+    max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+    """
+    Tokenize function designed for tokenizing the Math-Shepherd dataset.
+
+    The data point has the following format:
+    {
+        "input": problem + step-by-step solution,
+        "label": problem + step-by-step solution with automatic label,
+        "task": GSM8K or MATH
+    }
+
+    """
+    input = data_point["input"]
+    label = data_point["label"]
+
+    template = deepcopy(conversation_template)
+    template.append_message("user", input)
+    template.append_message("assistant", label)
+    prompt = template.get_prompt(add_generation_prompt=True)
+    reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
+    tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+    tokenized_tensor = torch.tensor(tokenized)
+    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
+
+    label = (tokenized_tensor * loss_mask).tolist()
+    decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
+    decoded_label = tokenizer.decode(label, skip_special_tokens=False)
+
+    return {
+        "input": tokenized,
+        "label": label,
+        "loss_mask": loss_mask,
+        "decoded_input": decoded_input,
+        "decoded_label": decoded_label,
+    }
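
For context (not part of the patch), below is a minimal usage sketch of the new tokenize_process_reward function. It is a sketch under stated assumptions: the model id is a placeholder, construction of the Conversation template is elided because it is config-specific, and it assumes the template exposes a reward_signal attribute listing the per-step label tokens, which must exist in the tokenizer's vocabulary. In Math-Shepherd-style data, "input" tags each solution step with a step token and "label" replaces each tag with a correctness token ("+" for a good step, "-" for a bad one).

from transformers import AutoTokenizer

from coati.dataset.tokenization_utils import tokenize_process_reward

# Placeholder model id; any tokenizer whose vocabulary contains the
# reward-signal tokens declared by the conversation template should work.
tokenizer = AutoTokenizer.from_pretrained("some-org/some-base-model")

# Assumed: a coati Conversation whose `reward_signal` attribute lists the
# per-step label tokens (construction elided; it is repo/config-specific).
conversation_template = ...

# One Math-Shepherd-style data point: "input" marks each step with a step
# tag ("ки" in the published dataset, shown illustratively here), and
# "label" replaces each tag with a correctness token.
data_point = {
    "input": "Janet has 3 apples... Step 1: ... ки Step 2: ... ки",
    "label": "Janet has 3 apples... Step 1: ... + Step 2: ... -",
    "task": "GSM8K",
}

result = tokenize_process_reward(data_point, tokenizer, conversation_template)

# `loss_mask` is True exactly at reward-signal token positions, and `label`
# keeps the token id there (0 elsewhere, from the mask multiplication), so a
# process reward model can compute its loss on the per-step labels alone.
print(result["decoded_input"])
print(result["loss_mask"])

Returning decoded_input and decoded_label alongside the token ids makes it easy to eyeball whether the chat template placed the reward-signal tokens where expected before training on the dataset.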