update tokenization function

pull/6119/head
Tong Li 2024-11-11 07:26:32 +00:00
parent dcb509c8e3
commit 1210dbea97
1 changed file with 45 additions and 1 deletion

@@ -1,13 +1,14 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-tokenization utils for constructing dataset for ppo, dpo, sft, rm
+Tokenization Utils for Constructing Dataset for RL.
 """
 import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Union
+import torch
 from coati.dataset.conversation import Conversation
 from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
 from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
     }
+
+
+def tokenize_process_reward(
+    data_point: Dict[str, str],
+    tokenizer: PreTrainedTokenizer,
+    conversation_template: Conversation = None,
+    max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+    """
+    Tokenization function for the Math-Shepherd process-reward dataset.
+    Each data point has the following format:
+        {
+            "input": problem + step-by-step solution,
+            "label": problem + step-by-step solution with automatic label,
+            "task": GSM8K or MATH
+        }
+    """
+    input = data_point["input"]
+    label = data_point["label"]
+    template = deepcopy(conversation_template)
+    template.append_message("user", input)
+    template.append_message("assistant", label)
+    prompt = template.get_prompt(add_generation_prompt=True)
+    reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
+    tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+    tokenized_tensor = torch.tensor(tokenized)
+    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
+    label = (tokenized_tensor * loss_mask).tolist()
+    decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
+    decoded_label = tokenizer.decode(label, skip_special_tokens=False)
+    return {
+        "input": tokenized,
+        "label": label,
+        "loss_mask": loss_mask,
+        "decoded_input": decoded_input,
+        "decoded_label": decoded_label,
+    }
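
For context, a rough usage sketch of the new helper. Everything other than tokenize_process_reward itself is an assumption for illustration and not part of this commit: the tokenizer checkpoint, the sample data point, the "<reward>" placeholder token, and the prepared conversation_template, whose reward_signal tokens must already exist in the tokenizer vocabulary.

# Illustrative only: `conversation_template` is assumed to come from the existing
# dataset-preparation pipeline, with `reward_signal` set to real vocabulary tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # any HF tokenizer

problem = "Janet has 3 apples and buys 2 more. How many apples does she have?"
solution = "Step 1: 3 + 2 = 5. Step 2: The answer is 5."
data_point = {
    # Matches the docstring layout: both fields contain problem + step-by-step solution,
    # and "label" additionally carries a reward-signal token after each step.
    "input": problem + " " + solution,
    "label": problem + " Step 1: 3 + 2 = 5. <reward> Step 2: The answer is 5. <reward>",
    "task": "GSM8K",
}

sample = tokenize_process_reward(data_point, tokenizer, conversation_template, max_length=4096)
print(sample["decoded_input"])         # chat-formatted problem + labelled solution
print(int(sample["loss_mask"].sum()))  # number of reward-signal positions (one per step)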