mirror of https://github.com/hpcaitech/ColossalAI
refactor tokenization
parent
544b7a38a1
commit
d49550fb49
|
@ -18,6 +18,7 @@ class Conversation:
|
|||
chat_template: str
|
||||
stop_ids: List[int]
|
||||
end_of_assistant: str
|
||||
roles = ["user", "assistant"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
|
||||
|
@ -85,7 +86,7 @@ class Conversation:
|
|||
Raises:
|
||||
AssertionError: If the role is not 'user' or 'assistant'.
|
||||
"""
|
||||
assert role in ["user", "assistant"]
|
||||
assert role in self.roles
|
||||
self.messages.append({"role": role, "content": message})
|
||||
|
||||
def copy(self):
|
||||
|
|
|
@ -39,7 +39,7 @@ def supervised_tokenize_sft(
|
|||
|
||||
Args:
|
||||
data_point: the data point of the following format
|
||||
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
|
||||
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
|
||||
tokenizer: the tokenizer whose
|
||||
conversation_template: the conversation template to apply
|
||||
ignore_index: the ignore index when calculate loss during training
|
||||
|
@ -52,41 +52,25 @@ def supervised_tokenize_sft(
|
|||
messages = data_point["messages"]
|
||||
template = deepcopy(conversation_template)
|
||||
template.messages = []
|
||||
|
||||
for mess in messages:
|
||||
from_str = mess["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
|
||||
template.append_message(from_str, mess["content"])
|
||||
for idx, mess in enumerate(messages):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{messages}"
|
||||
)
|
||||
template.append_message(mess["from"], mess["content"])
|
||||
|
||||
if len(template.messages) % 2 != 0:
|
||||
# Force to end with assistant response
|
||||
template.messages = template.messages[0:-1]
|
||||
|
||||
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
|
||||
turns = [i for i in range(1, len(messages) // 2 + 1)]
|
||||
|
||||
lo, hi = 0, len(turns)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
prompt = template.get_prompt(2 * turns[mid] - 1)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
if max_length - 1 < len(tokenized):
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 1
|
||||
target_turn_index = lo
|
||||
|
||||
# The tokenized length for first turn already exceeds `max_length - 1`.
|
||||
if target_turn_index - 1 < 0:
|
||||
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
|
||||
# tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
|
||||
prompt = template.get_prompt()
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages, prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
|
||||
if tokenized is None:
|
||||
return dict(
|
||||
input_ids=None,
|
||||
labels=None,
|
||||
|
@ -96,45 +80,18 @@ def supervised_tokenize_sft(
|
|||
seq_category=None,
|
||||
)
|
||||
|
||||
target_turn = turns[target_turn_index - 1]
|
||||
prompt = template.get_prompt(2 * target_turn)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
|
||||
labels = [ignore_index] * len(tokenized)
|
||||
for start, end in zip(starts, ends):
|
||||
if end == len(tokenized):
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
labels = labels + [ignore_index]
|
||||
labels[start:end] = tokenized[start:end]
|
||||
|
||||
# truncate the sequence at the last token that requires loss calculation
|
||||
to_truncate_len = 0
|
||||
for i in range(len(tokenized) - 1, -1, -1):
|
||||
if labels[i] == ignore_index:
|
||||
to_truncate_len += 1
|
||||
else:
|
||||
break
|
||||
to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
|
||||
tokenized = tokenized[: len(tokenized) - to_truncate_len]
|
||||
labels = labels[: len(labels) - to_truncate_len]
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
# Force to add bos token at the beginning of the tokenized sequence if the input ids doesn;t starts with bos
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
# Some chat templates already include bos token
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
labels = [ignore_index] + labels
|
||||
labels = [-100] + labels
|
||||
|
||||
if tokenizer.eos_token_id is not None:
|
||||
# Force to add eos token at the end of the tokenized sequence
|
||||
if tokenized[-1] != tokenizer.eos_token_id:
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
labels = labels + [tokenizer.eos_token_id]
|
||||
else:
|
||||
labels[-1] = tokenizer.eos_token_id
|
||||
|
||||
# For some model without bos/eos may raise the following errors
|
||||
# log decoded inputs and labels for debugging
|
||||
inputs_decode = tokenizer.decode(tokenized)
|
||||
start = 0
|
||||
end = 0
|
||||
|
@ -183,7 +140,7 @@ def tokenize_prompt_dataset(
|
|||
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
|
||||
Args:
|
||||
data_point: the data point of the following format
|
||||
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
|
||||
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
|
||||
tokenizer: the tokenizer whose
|
||||
conversation_template: the conversation template to apply
|
||||
ignore_index: the ignore index when calculate loss during training
|
||||
|
@ -196,35 +153,28 @@ def tokenize_prompt_dataset(
|
|||
template = deepcopy(conversation_template)
|
||||
template.messages = []
|
||||
|
||||
for mess in messages:
|
||||
from_str = mess["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
|
||||
template.append_message(from_str, mess["content"])
|
||||
for idx, mess in enumerate(messages):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{messages}"
|
||||
)
|
||||
template.append_message(mess["from"], mess["content"])
|
||||
|
||||
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
|
||||
target_turn = len(template.messages)
|
||||
if target_turn % 2 != 1:
|
||||
if len(template.messages) % 2 != 1:
|
||||
# exclude the answer if provided. keep only the prompt
|
||||
target_turn = target_turn - 1
|
||||
template.messages = template.messages[:-1]
|
||||
|
||||
# Prepare data
|
||||
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
|
||||
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
|
||||
# Skip overlength data
|
||||
if max_length - 1 < len(tokenized):
|
||||
if len(tokenized) > max_length:
|
||||
return dict(
|
||||
input_ids=None,
|
||||
inputs_decode=None,
|
||||
|
@ -235,47 +185,32 @@ def tokenize_prompt_dataset(
|
|||
# `inputs_decode` can be used to check whether the tokenization method is true.
|
||||
return dict(
|
||||
input_ids=tokenized,
|
||||
inputs_decode=tokenizer.decode(tokenized),
|
||||
inputs_decode=prompt,
|
||||
seq_length=len(tokenized),
|
||||
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||
)
|
||||
|
||||
|
||||
def apply_rlhf_data_format(
|
||||
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
|
||||
):
|
||||
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
|
||||
target_turn = int(len(template.messages) / 2)
|
||||
prompt = template.get_prompt(target_turn * 2)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
loss_mask = [0] * len(tokenized)
|
||||
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
|
||||
if mask_token is None:
|
||||
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
|
||||
# no truncation applied
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
|
||||
|
||||
loss_mask = [0] * len(tokenized)
|
||||
label_decode = []
|
||||
for start, end in zip(starts[-1:], ends[-1:]):
|
||||
# only the last round (chosen/rejected) counts
|
||||
if end == len(tokenized):
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
loss_mask = loss_mask + [1]
|
||||
loss_mask[start:end] = [1] * len(loss_mask[start:end])
|
||||
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
|
||||
# only the last round (chosen/rejected) is used to calculate loss
|
||||
for i in range(starts[-1], ends[-1]):
|
||||
loss_mask[i] = 1
|
||||
label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
|
||||
if tokenizer.bos_token_id is not None:
|
||||
if tokenized[0] != tokenizer.bos_token_id:
|
||||
tokenized = [tokenizer.bos_token_id] + tokenized
|
||||
loss_mask = [0] + loss_mask
|
||||
|
||||
if tokenizer.eos_token_id is not None:
|
||||
# Force to add eos token at the end of the tokenized sequence
|
||||
if tokenized[-1] != tokenizer.eos_token_id:
|
||||
tokenized = tokenized + [tokenizer.eos_token_id]
|
||||
loss_mask = loss_mask + [1]
|
||||
else:
|
||||
loss_mask[-1] = 1
|
||||
|
||||
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
|
||||
|
||||
|
||||
|
@ -288,7 +223,7 @@ def tokenize_rlhf(
|
|||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
A tokenization function to tokenize an original pretraining data point as following:
|
||||
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
|
||||
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
|
||||
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
|
||||
"""
|
||||
if ignore_index is None:
|
||||
|
@ -298,24 +233,17 @@ def tokenize_rlhf(
|
|||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
for mess in context:
|
||||
from_str = mess["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
|
||||
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
|
||||
# Concate adjacent message from the same role
|
||||
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
|
||||
else:
|
||||
template.append_message(from_str, mess["content"])
|
||||
for idx, mess in enumerate(context):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{context}"
|
||||
)
|
||||
template.append_message(mess["from"], mess["content"])
|
||||
|
||||
if len(template.messages) % 2 != 1:
|
||||
warnings.warn(
|
||||
"Please make sure leading context starts and ends with a line from human\nLeading context: "
|
||||
"Please make sure leading context starts and ends with a line from user\nLeading context: "
|
||||
+ str(template.messages)
|
||||
)
|
||||
return dict(
|
||||
|
@ -326,31 +254,27 @@ def tokenize_rlhf(
|
|||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
round_of_context = int((len(template.messages) - 1) / 2)
|
||||
|
||||
assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
|
||||
assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
|
||||
chosen = deepcopy(template)
|
||||
rejected = deepcopy(template)
|
||||
chosen_continuation = data_point["chosen"]
|
||||
rejected_continuation = data_point["rejected"]
|
||||
for round in range(len(chosen_continuation)):
|
||||
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{chosen_continuation}"
|
||||
)
|
||||
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
|
||||
|
||||
for round in range(len(data_point["chosen"])):
|
||||
from_str = data_point["chosen"][round]["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
chosen.append_message(from_str, data_point["chosen"][round]["content"])
|
||||
|
||||
for round in range(len(data_point["rejected"])):
|
||||
from_str = data_point["rejected"][round]["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = "user"
|
||||
elif from_str.lower() == "assistant":
|
||||
from_str = "assistant"
|
||||
else:
|
||||
raise ValueError(f"Unsupported role {from_str.lower()}")
|
||||
rejected.append_message(from_str, data_point["rejected"][round]["content"])
|
||||
for round in range(len(rejected_continuation)):
|
||||
if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{rejected_continuation}"
|
||||
)
|
||||
rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
|
||||
|
||||
(
|
||||
chosen_input_ids,
|
||||
|
@ -361,16 +285,14 @@ def tokenize_rlhf(
|
|||
rejected_label_decode,
|
||||
) = (None, None, None, None, None, None)
|
||||
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
rejected_data_packed = apply_rlhf_data_format(
|
||||
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
|
||||
)
|
||||
rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
|
@ -387,7 +309,7 @@ def tokenize_rlhf(
|
|||
rejected_label_decode=None,
|
||||
)
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
|
||||
if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
|
@ -411,14 +333,13 @@ def tokenize_kto(
|
|||
data_point: Dict[str, str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
conversation_template: Conversation = None,
|
||||
ignore_index: int = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
Tokenize a dataset for KTO training
|
||||
The raw input data is conversation that have the following format
|
||||
{
|
||||
"prompt": [{"from": "human", "content": "xxx"}...],
|
||||
"prompt": [{"from": "user", "content": "xxx"}...],
|
||||
"completion": {"from": "assistant", "content": "xxx"},
|
||||
"label": true/false
|
||||
}
|
||||
|
@ -427,21 +348,18 @@ def tokenize_kto(
|
|||
the completion, which only contains the assistance's answer,
|
||||
and a binary label, which indicates if the sample is prefered or not
|
||||
"""
|
||||
if ignore_index is None:
|
||||
ignore_index = IGNORE_INDEX
|
||||
|
||||
prompt = data_point["prompt"]
|
||||
completion = data_point["completion"]
|
||||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
if prompt[0].get("from", None) != "human":
|
||||
raise ValueError("conversation should start with human")
|
||||
if prompt[0].get("from", None) != "user":
|
||||
raise ValueError("conversation should start with user")
|
||||
if completion.get("from", None) != "assistant":
|
||||
raise ValueError("conversation should end with assistant")
|
||||
|
||||
for mess in prompt:
|
||||
if mess.get("from", None) == "human":
|
||||
if mess.get("from", None) == "user":
|
||||
template.append_message("user", mess["content"])
|
||||
elif mess.get("from", None) == "assistant":
|
||||
template.append_message("assistant", mess["content"])
|
||||
|
|
|
@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
|
|||
return -1
|
||||
|
||||
|
||||
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
|
||||
def tokenize_and_concatenate(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
text: List[str],
|
||||
require_loss: List[bool],
|
||||
max_length: int,
|
||||
discard_non_loss_tokens_at_tail: bool = True,
|
||||
):
|
||||
"""
|
||||
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
|
||||
|
||||
|
@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
|
|||
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
|
||||
text (List[str]): The list of texts to tokenize.
|
||||
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
|
||||
max_length: used to truncate the input ids
|
||||
discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
|
||||
|
||||
if the first round has already exeeded max length
|
||||
- if the user query already exeeded max length, discard the sample
|
||||
- if only the first assistant response exeeded max length, truncate the response to fit the max length
|
||||
else keep the first several complete rounds of the conversations until max length is reached
|
||||
|
||||
Returns:
|
||||
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
|
||||
|
@ -106,10 +119,17 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
|
|||
loss_ends = []
|
||||
for s, r in zip(text, require_loss):
|
||||
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
|
||||
if r:
|
||||
loss_starts.append(len(input_ids))
|
||||
loss_ends.append(len(input_ids) + len(tokenized))
|
||||
input_ids.extend(tokenized)
|
||||
if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
|
||||
if r:
|
||||
loss_starts.append(len(input_ids))
|
||||
loss_ends.append(len(input_ids) + len(tokenized))
|
||||
input_ids.extend(tokenized)
|
||||
if loss_starts[0] >= max_length:
|
||||
return None, None, None
|
||||
if discard_non_loss_tokens_at_tail:
|
||||
input_ids = input_ids[: loss_ends[-1]]
|
||||
input_ids = input_ids[:max_length]
|
||||
loss_ends[-1] = min(max_length, loss_ends[-1])
|
||||
return input_ids, loss_starts, loss_ends
|
||||
|
||||
|
||||
|
@ -125,6 +145,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
|
|||
content_length = (
|
||||
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
|
||||
)
|
||||
# if the tokenized content start with a leading space, we want to keep it in loss calculation
|
||||
# e.g., Assistant: I am saying...
|
||||
# if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation
|
||||
# e.g.,
|
||||
# Assistant: # '\n' as line breaker
|
||||
# I am saying...
|
||||
if prompt[first_occur - 1] != " ":
|
||||
chunks.append(prompt[start_idx:first_occur])
|
||||
chunks.append(prompt[first_occur : first_occur + content_length])
|
||||
|
|
|
@ -448,7 +448,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
|
|||
{"messages":
|
||||
[
|
||||
{
|
||||
"from": "human",
|
||||
"from": "user",
|
||||
"content": "what are some pranks with a pen i can do?"
|
||||
},
|
||||
{
|
||||
|
@ -529,7 +529,7 @@ Below shows the preference dataset format used in training the reward model.
|
|||
[
|
||||
{"context": [
|
||||
{
|
||||
"from": "human",
|
||||
"from": "user",
|
||||
"content": "Introduce butterflies species in Oregon."
|
||||
}
|
||||
]
|
||||
|
@ -598,7 +598,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
|
|||
|
||||
|
||||
#### Step 1: Data Collection
|
||||
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
||||
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "user" and thus the "assistant" needs to generate a response to answer to the "user". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
||||
|
||||
|
||||
```json
|
||||
|
@ -606,7 +606,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
|
|||
{"messages":
|
||||
[
|
||||
{
|
||||
"from": "human",
|
||||
"from": "user",
|
||||
"content": "what are some pranks with a pen i can do?"
|
||||
}
|
||||
...
|
||||
|
@ -761,12 +761,12 @@ For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examp
|
|||
```jsonl
|
||||
{
|
||||
"prompt": [
|
||||
{"from": "human", "content": "What are some cuss words in english?"},
|
||||
{"from": "user", "content": "What are some cuss words in english?"},
|
||||
{
|
||||
"from": "assistant",
|
||||
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...",
|
||||
},
|
||||
{"from": "human", "content": "What's your favorite one?"},
|
||||
{"from": "user", "content": "What's your favorite one?"},
|
||||
],
|
||||
"completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant.
|
||||
"label": False, # whether the response is favorable or not
|
||||
|
|
|
@ -1 +1 @@
|
|||
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
|
||||
{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
{"prompt": [{"from": "user", "content": "What are some cuss words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "Ass."},"label": false}
|
|
@ -1 +1 @@
|
|||
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
|
||||
{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
|
||||
|
|
|
@ -94,7 +94,7 @@ done
|
|||
|
||||
# Test DPO/PPO data Preparation
|
||||
for model in ${MODELS[@]}; do
|
||||
echo "Testing DPO/PPO data templating for $model"
|
||||
echo "Testing DPO/RM data templating for $model"
|
||||
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
|
@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
|
|||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the DPO data templating for $model"
|
||||
echo "[Test]: Failed in the DPO/RM data templating for $model"
|
||||
exit 1
|
||||
fi
|
||||
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
|
||||
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the DPO data templating test for $model"
|
||||
echo "[Test]: Failed in the DPO/RM data templating test for $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
# Test KTO data Preparation
|
||||
for model in ${MODELS[@]}; do
|
||||
echo "Testing KTO data templating for $model"
|
||||
SAVE_DIR=$DATA_SAVE_PATH/kto/$model
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
rm -rf $SAVE_DIR/arrow
|
||||
pretrain=$(get_pretrain $model)
|
||||
conversation_template_config=$(get_conversation_template_config $model)
|
||||
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
|
||||
--tokenizer_dir $pretrain \
|
||||
--conversation_template_config $conversation_template_config \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||
--data_arrow_output_dir $SAVE_DIR/arrow
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the KTO data templating for $model"
|
||||
exit 1
|
||||
fi
|
||||
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
|
||||
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
|
||||
passed=$?
|
||||
if [ $passed -ne 0 ]; then
|
||||
echo "[Test]: Failed in the KTO data templating test for $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
|
|
@ -62,3 +62,11 @@ if __name__ == "__main__":
|
|||
assert any(
|
||||
[rejected_lable in s for s in to_verify_lable_rejected]
|
||||
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
|
||||
elif args.data_type == "kto":
|
||||
sample = data[0]
|
||||
to_verify_data = to_verify_data[0]
|
||||
for line in sample["prompt"]:
|
||||
assert line["content"] in to_verify_data["input_id_decode"]
|
||||
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
|
||||
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
|
||||
assert sample["label"] == to_verify_data["label"]
|
||||
|
|
Loading…
Reference in New Issue