mirror of https://github.com/hpcaitech/ColossalAI

update tokenization function

parent dcb509c8e3
commit 1210dbea97
@@ -1,13 +1,14 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-tokenization utils for constructing dataset for ppo, dpo, sft, rm
+Tokenization Utils for Constructing Dataset for RL.
 """
 
 import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Union
 
+import torch
 from coati.dataset.conversation import Conversation
 from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
 from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
     }
+
+
+def tokenize_process_reward(
+    data_point: Dict[str, str],
+    tokenizer: PreTrainedTokenizer,
+    conversation_template: Conversation = None,
+    max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+    """
+    Tokenization function for the Math-Shepherd dataset.
+
+    The data point has the following format:
+    {
+        "input": problem + step-by-step solution,
+        "label": problem + step-by-step solution with automatic label,
+        "task": GSM8K or MATH
+    }
+
+    """
+    input = data_point["input"]
+    label = data_point["label"]
+
+    template = deepcopy(conversation_template)
+    template.append_message("user", input)
+    template.append_message("assistant", label)
+    prompt = template.get_prompt(add_generation_prompt=True)
+    reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
+    tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+    tokenized_tensor = torch.tensor(tokenized)
+    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
+
+    label = (tokenized_tensor * loss_mask).tolist()
+    decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
+    decoded_label = tokenizer.decode(label, skip_special_tokens=False)
+
+    return {
+        "input": tokenized,
+        "label": label,
+        "loss_mask": loss_mask,
+        "decoded_input": decoded_input,
+        "decoded_label": decoded_label,
+    }
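The core of the added function is the torch.isin masking step: label keeps the reward-signal token ids (the per-step correctness marks that Math-Shepherd attaches to each solution step) and zeroes every other position, so a process reward model computes loss only where those marks sit. A minimal runnable sketch of just that step, with made-up token ids standing in for a real tokenizer's output (101/102 and 32001/32002 are illustrative, not from any actual vocabulary):

import torch

# Stand-ins for tokenizer(prompt)["input_ids"]; 32001 and 32002 play the
# role of the template's reward_signal tokens.
tokenized = [101, 7592, 2088, 32001, 2003, 4997, 32002, 102]
reward_signal_id = [32001, 32002]

tokenized_tensor = torch.tensor(tokenized)
# Boolean mask that is True exactly at reward-signal positions.
loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))

# Multiplying by the mask zeroes non-signal positions and keeps the
# signal ids as the only nonzero training targets.
label = (tokenized_tensor * loss_mask).tolist()

print(loss_mask.tolist())  # [False, False, False, True, False, False, True, False]
print(label)               # [0, 0, 0, 32001, 0, 0, 32002, 0]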
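For context, a hedged usage sketch: tokenizer and conversation_template are assumed to be configured elsewhere (a transformers PreTrainedTokenizer whose vocabulary contains the template's reward_signal tokens), and the record below is an invented stand-in for a real Math-Shepherd data point, not actual dataset content.

# Invented Math-Shepherd-style record; real ones mark each solution step
# with an automatic correctness label.
data_point = {
    "input": "Question: ... Step 1: ... Step 2: ...",
    "label": "Question: ... Step 1: ... + Step 2: ... -",
    "task": "GSM8K",
}

sample = tokenize_process_reward(
    data_point,
    tokenizer=tokenizer,                          # assumed to exist
    conversation_template=conversation_template,  # must define reward_signal
)
# sample["loss_mask"] is True only at the step-label positions;
# sample["label"] holds the signal ids there and 0 everywhere else.

This mirrors how Math-Shepherd provides process supervision: the model is graded on each intermediate solution step rather than only on the final answer.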