refactor tokenization

pull/5922/head
YeAnbang 2024-07-19 10:10:48 +00:00
parent 544b7a38a1
commit d49550fb49
9 changed files with 159 additions and 175 deletions

View File

@ -18,6 +18,7 @@ class Conversation:
chat_template: str
stop_ids: List[int]
end_of_assistant: str
roles = ["user", "assistant"]
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@ -85,7 +86,7 @@ class Conversation:
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
assert role in ["user", "assistant"]
assert role in self.roles
self.messages.append({"role": role, "content": message})
def copy(self):

View File

@ -39,7 +39,7 @@ def supervised_tokenize_sft(
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
@ -52,41 +52,25 @@ def supervised_tokenize_sft(
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 0:
# Force to end with assistant response
template.messages = template.messages[0:-1]
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
turns = [i for i in range(1, len(messages) // 2 + 1)]
lo, hi = 0, len(turns)
while lo < hi:
mid = (lo + hi) // 2
prompt = template.get_prompt(2 * turns[mid] - 1)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
if max_length - 1 < len(tokenized):
hi = mid
else:
lo = mid + 1
target_turn_index = lo
# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
# tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
prompt = template.get_prompt()
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages, prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
if tokenized is None:
return dict(
input_ids=None,
labels=None,
@ -96,45 +80,18 @@ def supervised_tokenize_sft(
seq_category=None,
)
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end]
# truncate the sequence at the last token that requires loss calculation
to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
if labels[i] == ignore_index:
to_truncate_len += 1
else:
break
to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]
if tokenizer.bos_token_id is not None:
# Force to add bos token at the beginning of the tokenized sequence if the input ids doesn;t starts with bos
if tokenized[0] != tokenizer.bos_token_id:
# Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] + labels
labels = [-100] + labels
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [tokenizer.eos_token_id]
else:
labels[-1] = tokenizer.eos_token_id
# For some model without bos/eos may raise the following errors
# log decoded inputs and labels for debugging
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
@ -183,7 +140,7 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
@ -196,35 +153,28 @@ def tokenize_prompt_dataset(
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
target_turn = len(template.messages)
if target_turn % 2 != 1:
if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
target_turn = target_turn - 1
template.messages = template.messages[:-1]
# Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
# Skip overlength data
if max_length - 1 < len(tokenized):
if len(tokenized) > max_length:
return dict(
input_ids=None,
inputs_decode=None,
@ -235,47 +185,32 @@ def tokenize_prompt_dataset(
# `inputs_decode` can be used to check whether the tokenization method is true.
return dict(
input_ids=tokenized,
inputs_decode=tokenizer.decode(tokenized),
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def apply_rlhf_data_format(
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
loss_mask = [0] * len(tokenized)
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
if mask_token is None:
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
# no truncation applied
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
loss_mask = [0] * len(tokenized)
label_decode = []
for start, end in zip(starts[-1:], ends[-1:]):
# only the last round (chosen/rejected) counts
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
loss_mask[start:end] = [1] * len(loss_mask[start:end])
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
# only the last round (chosen/rejected) is used to calculate loss
for i in range(starts[-1], ends[-1]):
loss_mask[i] = 1
label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
else:
loss_mask[-1] = 1
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
@ -288,7 +223,7 @@ def tokenize_rlhf(
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following:
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
if ignore_index is None:
@ -298,24 +233,17 @@ def tokenize_rlhf(
template = deepcopy(conversation_template)
template.clear()
for mess in context:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
# Concate adjacent message from the same role
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
else:
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{context}"
)
template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 1:
warnings.warn(
"Please make sure leading context starts and ends with a line from human\nLeading context: "
"Please make sure leading context starts and ends with a line from user\nLeading context: "
+ str(template.messages)
)
return dict(
@ -326,31 +254,27 @@ def tokenize_rlhf(
rejected_loss_mask=None,
rejected_label_decode=None,
)
round_of_context = int((len(template.messages) - 1) / 2)
assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
chosen = deepcopy(template)
rejected = deepcopy(template)
chosen_continuation = data_point["chosen"]
rejected_continuation = data_point["rejected"]
for round in range(len(chosen_continuation)):
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{chosen_continuation}"
)
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
for round in range(len(data_point["chosen"])):
from_str = data_point["chosen"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
chosen.append_message(from_str, data_point["chosen"][round]["content"])
for round in range(len(data_point["rejected"])):
from_str = data_point["rejected"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
rejected.append_message(from_str, data_point["rejected"][round]["content"])
for round in range(len(rejected_continuation)):
if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{rejected_continuation}"
)
rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
(
chosen_input_ids,
@ -361,16 +285,14 @@ def tokenize_rlhf(
rejected_label_decode,
) = (None, None, None, None, None, None)
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
rejected_data_packed = apply_rlhf_data_format(
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
)
rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
@ -387,7 +309,7 @@ def tokenize_rlhf(
rejected_label_decode=None,
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
@ -411,14 +333,13 @@ def tokenize_kto(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
Tokenize a dataset for KTO training
The raw input data is conversation that have the following format
{
"prompt": [{"from": "human", "content": "xxx"}...],
"prompt": [{"from": "user", "content": "xxx"}...],
"completion": {"from": "assistant", "content": "xxx"},
"label": true/false
}
@ -427,21 +348,18 @@ def tokenize_kto(
the completion, which only contains the assistance's answer,
and a binary label, which indicates if the sample is prefered or not
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
prompt = data_point["prompt"]
completion = data_point["completion"]
template = deepcopy(conversation_template)
template.clear()
if prompt[0].get("from", None) != "human":
raise ValueError("conversation should start with human")
if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
raise ValueError("conversation should end with assistant")
for mess in prompt:
if mess.get("from", None) == "human":
if mess.get("from", None) == "user":
template.append_message("user", mess["content"])
elif mess.get("from", None) == "assistant":
template.append_message("assistant", mess["content"])

View File

@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
return -1
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
def tokenize_and_concatenate(
tokenizer: PreTrainedTokenizer,
text: List[str],
require_loss: List[bool],
max_length: int,
discard_non_loss_tokens_at_tail: bool = True,
):
"""
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
max_length: used to truncate the input ids
discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
if the first round has already exeeded max length
- if the user query already exeeded max length, discard the sample
- if only the first assistant response exeeded max length, truncate the response to fit the max length
else keep the first several complete rounds of the conversations until max length is reached
Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
@ -106,10 +119,17 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if loss_starts[0] >= max_length:
return None, None, None
if discard_non_loss_tokens_at_tail:
input_ids = input_ids[: loss_ends[-1]]
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends
@ -125,6 +145,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
content_length = (
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
)
# if the tokenized content start with a leading space, we want to keep it in loss calculation
# e.g., Assistant: I am saying...
# if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation
# e.g.,
# Assistant: # '\n' as line breaker
# I am saying...
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + content_length])

View File

@ -448,7 +448,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
@ -529,7 +529,7 @@ Below shows the preference dataset format used in training the reward model.
[
{"context": [
{
"from": "human",
"from": "user",
"content": "Introduce butterflies species in Oregon."
}
]
@ -598,7 +598,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
#### Step 1: Data Collection
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "user" and thus the "assistant" needs to generate a response to answer to the "user". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
@ -606,7 +606,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
}
...
@ -761,12 +761,12 @@ For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examp
```jsonl
{
"prompt": [
{"from": "human", "content": "What are some cuss words in english?"},
{"from": "user", "content": "What are some cuss words in english?"},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...",
},
{"from": "human", "content": "What's your favorite one?"},
{"from": "user", "content": "What's your favorite one?"},
],
"completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant.
"label": False, # whether the response is favorable or not

View File

@ -1 +1 @@
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}

View File

@ -0,0 +1 @@
{"prompt": [{"from": "user", "content": "What are some cuss words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "Ass."},"label": false}

View File

@ -1 +1 @@
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}

View File

@ -94,7 +94,7 @@ done
# Test DPO/PPO data Preparation
for model in ${MODELS[@]}; do
echo "Testing DPO/PPO data templating for $model"
echo "Testing DPO/RM data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating for $model"
echo "[Test]: Failed in the DPO/RM data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating test for $model"
echo "[Test]: Failed in the DPO/RM data templating test for $model"
exit 1
fi
done
# Test KTO data Preparation
for model in ${MODELS[@]}; do
echo "Testing KTO data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/kto/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating test for $model"
exit 1
fi
done

View File

@ -62,3 +62,11 @@ if __name__ == "__main__":
assert any(
[rejected_lable in s for s in to_verify_lable_rejected]
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
elif args.data_type == "kto":
sample = data[0]
to_verify_data = to_verify_data[0]
for line in sample["prompt"]:
assert line["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
assert sample["label"] == to_verify_data["label"]