mirror of https://github.com/hpcaitech/ColossalAI
update tokenization function
parent dcb509c8e3
commit 1210dbea97

@@ -1,13 +1,14 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-tokenization utils for constructing dataset for ppo, dpo, sft, rm
+Tokenization Utils for Constructing Dataset for RL.
 """
 
 import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Union
 
+import torch
 from coati.dataset.conversation import Conversation
 from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
 from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
     }
+
+
+def tokenize_process_reward(
+    data_point: Dict[str, str],
+    tokenizer: PreTrainedTokenizer,
+    conversation_template: Conversation = None,
+    max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+    """
+    Tokenize function designed for tokenizing Math-Shepherd dataset.
+
+    The datapoint has the following format:
+    {
+        "input": problem + step-by-step solution,
+        "label": problem + step-by-step solution with automatic label,
+        "task": GSM8K or MATH
+    }
+
+    """
+    input = data_point["input"]
+    label = data_point["label"]
+
+    template = deepcopy(conversation_template)
+    template.append_message("user", input)
+    template.append_message("assistant", label)
+    prompt = template.get_prompt(add_generation_prompt=True)
+    reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
+    tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+    tokenized_tensor = torch.tensor(tokenized)
+    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
+
+    label = (tokenized_tensor * loss_mask).tolist()
+    decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
+    decoded_label = tokenizer.decode(label, skip_special_tokens=False)
+
+    return {
+        "input": tokenized,
+        "label": label,
+        "loss_mask": loss_mask,
+        "decoded_input": decoded_input,
+        "decoded_label": decoded_label,
+    }
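
Below is a minimal, self-contained sketch (not part of the commit) of the masking trick the new function relies on: torch.isin marks the positions holding one of the template's reward-signal tokens, and multiplying the token-id tensor by that boolean mask zeroes out every other position. The concrete token ids and the reward_signal_id values here are invented for illustration; in the real function they come from template.reward_signal via tokenizer.convert_tokens_to_ids.

    import torch

    # Hypothetical ids for two reward-signal tokens, e.g. one marking a
    # correct step and one marking an incorrect step.
    reward_signal_id = [32000, 32001]

    # A toy tokenized prompt in which two step labels appear.
    tokenized = [1, 532, 9, 32000, 77, 4, 32001, 2]
    tokenized_tensor = torch.tensor(tokenized)

    # True exactly at the positions holding a reward-signal token.
    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
    # tensor([False, False, False,  True, False, False,  True, False])

    # Multiplication zeroes out everything else, so the label sequence keeps
    # the reward-signal ids at their original positions and 0 elsewhere.
    label = (tokenized_tensor * loss_mask).tolist()
    # [0, 0, 0, 32000, 0, 0, 32001, 0]

Presumably a process reward model trained on this output computes its loss only at the loss_mask positions, where label holds the automatically assigned step-correctness token and every other position is 0.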