#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Tokenization utilities for constructing datasets for PPO, DPO, SFT, and RM training.
"""

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

from coati.dataset.conversation import Conversation
from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset
from transformers import PreTrainedTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()

IGNORE_INDEX = -100

DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
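
# For reference, a minimal SFT data point consumed by the functions below looks
# like the following (illustrative values; real datasets may also carry extra
# fields such as "category"):
#
#   {"messages": [{"from": "human", "content": "What is 1 + 1?"},
#                 {"from": "assistant", "content": "1 + 1 = 2."}]}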


def supervised_tokenize_sft(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Optional[Conversation] = None,
    ignore_index: Optional[int] = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize an original pretraining data point as follows
    and calculate the corresponding labels for SFT training:
        "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line end]Something here"
                                            ^
                              end_of_system_line_position

    Args:
        data_point: the data point of the following format
            {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
        tokenizer: the tokenizer to use
        conversation_template: the conversation template to apply
        ignore_index: the ignore index used when calculating the loss during training
        max_length: the maximum context length
    """

    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    messages = data_point["messages"]
    template = deepcopy(conversation_template)
    template.messages = []

    for mess in messages:
        from_str = mess["from"]
        if from_str is None:
            raise ValueError(f"Missing 'from' field in message: {mess}")
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")

        template.append_message(from_str, mess["content"])

    if len(template.messages) % 2 != 0:
        # Drop a trailing message so the conversation ends with an assistant reply.
        template.messages = template.messages[0:-1]

    # Binary search for `target_turn_index`: the index of the first turn whose
    # tokenized prompt exceeds `max_length - 1`, so `turns[target_turn_index - 1]`
    # is the largest number of turns that still fits in the context window.
    turns = [i for i in range(1, len(messages) // 2 + 1)]

    lo, hi = 0, len(turns)
    while lo < hi:
        mid = (lo + hi) // 2
        if max_length - 1 < len(
            tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
        ):
            hi = mid
        else:
            lo = mid + 1
    target_turn_index = lo

    # The tokenized length of the first turn already exceeds `max_length - 1`.
    if target_turn_index - 1 < 0:
        warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
        return dict(
            input_ids=None,
            labels=None,
            inputs_decode=None,
            labels_decode=None,
            seq_length=None,
            seq_category=None,
        )

    target_turn = turns[target_turn_index - 1]
    prompt = template.get_prompt(2 * target_turn)
    chunks, require_loss = split_templated_prompt_into_chunks(
        template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
    )
    tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)

    labels = [ignore_index] * len(tokenized)
    for start, end in zip(starts, ends):
        if end == len(tokenized):
            tokenized = tokenized + [tokenizer.eos_token_id]
            labels = labels + [ignore_index]
        labels[start:end] = tokenized[start:end]

    # Truncate the sequence at the last token that requires loss calculation.
    to_truncate_len = 0
    for i in range(len(tokenized) - 1, -1, -1):
        if labels[i] == ignore_index:
            to_truncate_len += 1
        else:
            break
    tokenized = tokenized[: len(tokenized) - to_truncate_len]
    labels = labels[: len(labels) - to_truncate_len]

    if tokenizer.bos_token_id is not None:
        if tokenized[0] != tokenizer.bos_token_id:
            tokenized = [tokenizer.bos_token_id] + tokenized
            labels = [ignore_index] + labels

    if tokenizer.eos_token_id is not None:
        # Force adding an eos token at the end of the tokenized sequence.
        if tokenized[-1] != tokenizer.eos_token_id:
            tokenized = tokenized + [tokenizer.eos_token_id]
            labels = labels + [tokenizer.eos_token_id]
        else:
            labels[-1] = tokenizer.eos_token_id

    # Models without bos/eos tokens may raise errors during decoding.
    try:
        inputs_decode = tokenizer.decode(tokenized)
        start = 0
        end = 0
        label_decode = []
        # Decode each contiguous span of non-ignored labels for inspection.
        for i in range(len(labels)):
            if labels[i] == ignore_index:
                if start != end:
                    label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
                start = i
                end = i
            else:
                end = i
                if i == len(labels) - 1:
                    label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
    except TypeError as e:
        raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")

    # Check if all labels are ignored; this may happen when the tokenized length is too long.
    if labels.count(ignore_index) == len(labels):
        return dict(
            input_ids=None,
            labels=None,
            inputs_decode=None,
            labels_decode=None,
            seq_length=None,
            seq_category=None,
        )

    return dict(
        input_ids=tokenized,
        labels=labels,
        inputs_decode=inputs_decode,
        labels_decode=label_decode,
        seq_length=len(tokenized),
        seq_category=data_point["category"] if "category" in data_point else "None",
    )
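
# A minimal usage sketch (illustrative only; the surrounding training scripts may
# differ): the function is typically applied over a HuggingFace dataset via
# `.map`, with `tokenizer` and `conversation_template` prepared by the caller.
#
#   from functools import partial
#   tokenized_dataset = raw_dataset.map(
#       partial(
#           supervised_tokenize_sft,
#           tokenizer=tokenizer,
#           conversation_template=conversation_template,
#           max_length=4096,
#       ),
#       remove_columns=raw_dataset.column_names,
#   )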


def tokenize_prompt_dataset(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Optional[Conversation] = None,
    ignore_index: Optional[int] = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize an original pretraining data point as follows for PPO training:
        "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
    Args:
        data_point: the data point of the following format
            {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
        tokenizer: the tokenizer to use
        conversation_template: the conversation template to apply
        ignore_index: the ignore index used when calculating the loss during training
        max_length: the maximum context length
    """
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    messages = data_point["messages"]
    template = deepcopy(conversation_template)
    template.messages = []

    for mess in messages:
        from_str = mess["from"]
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")

        template.append_message(from_str, mess["content"])

    # Keep only the prompt part: if the final assistant answer is provided, drop it.
    target_turn = len(template.messages)
    if target_turn % 2 != 1:
        # Exclude the answer if provided; keep only the prompt.
        target_turn = target_turn - 1

    # Prepare data
    prompt = template.get_prompt(target_turn, add_generation_prompt=True)
    chunks, require_loss = split_templated_prompt_into_chunks(
        template.messages[:target_turn], prompt, conversation_template.end_of_assistant
    )
    tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
    if tokenizer.bos_token_id is not None:
        if tokenized[0] != tokenizer.bos_token_id:
            tokenized = [tokenizer.bos_token_id] + tokenized

    # Skip overlength data.
    if max_length - 1 < len(tokenized):
        return dict(
            input_ids=None,
            inputs_decode=None,
            seq_length=None,
            seq_category=None,
        )

    # `inputs_decode` can be used to check whether the tokenization is correct.
    return dict(
        input_ids=tokenized,
        inputs_decode=tokenizer.decode(tokenized),
        seq_length=len(tokenized),
        seq_category=data_point["category"] if "category" in data_point else "None",
    )
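
# Unlike `supervised_tokenize_sft`, this function keeps only the prompt part and
# appends the assistant generation prompt, so `input_ids` is ready to be fed to
# the actor model for rollout during PPO. A hypothetical decoded example
# (template-dependent, shown only for illustration):
#
#   "<s>[INST] What is 1 + 1? [/INST]"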


def apply_rlhf_data_format(
    template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
    target_turn = int(len(template.messages) / 2)
    prompt = template.get_prompt(target_turn * 2)
    chunks, require_loss = split_templated_prompt_into_chunks(
        template.messages[: 2 * target_turn], prompt, template.end_of_assistant
    )
    tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
    loss_mask = [0] * len(tokenized)
    mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
    if mask_token is None:
        mask_token = 1  # Fallback for tokenizers without an eos_token or pad_token (e.g. Qwen).

    label_decode = []
    # Only the last round (the chosen/rejected response) counts toward the loss.
    for start, end in zip(starts[-1:], ends[-1:]):
        if end == len(tokenized):
            tokenized = tokenized + [tokenizer.eos_token_id]
            loss_mask = loss_mask + [1]
        loss_mask[start:end] = [1] * len(loss_mask[start:end])
        label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
    if tokenizer.bos_token_id is not None:
        if tokenized[0] != tokenizer.bos_token_id:
            tokenized = [tokenizer.bos_token_id] + tokenized
            loss_mask = [0] + loss_mask

    if tokenizer.eos_token_id is not None:
        # Force adding an eos token at the end of the tokenized sequence.
        if tokenized[-1] != tokenizer.eos_token_id:
            tokenized = tokenized + [tokenizer.eos_token_id]
            loss_mask = loss_mask + [1]
        else:
            loss_mask[-1] = 1

    return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}


def tokenize_rlhf(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Optional[Conversation] = None,
    ignore_index: Optional[int] = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize an original pretraining data point of the following format:
        {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
         "chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
    """
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    context = data_point["context"]
    template = deepcopy(conversation_template)
    template.clear()

    for mess in context:
        from_str = mess["from"]
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")

        if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
            # Concatenate adjacent messages from the same role.
            template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
        else:
            template.append_message(from_str, mess["content"])

    if len(template.messages) % 2 != 1:
        warnings.warn(
            "Please make sure the leading context starts and ends with a line from the human\nLeading context: "
            + str(template.messages)
        )
        return dict(
            chosen_input_ids=None,
            chosen_loss_mask=None,
            chosen_label_decode=None,
            rejected_input_ids=None,
            rejected_loss_mask=None,
            rejected_label_decode=None,
        )
    round_of_context = int((len(template.messages) - 1) / 2)

    assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
    chosen = deepcopy(template)
    rejected = deepcopy(template)

    for turn_idx in range(len(data_point["chosen"])):
        from_str = data_point["chosen"][turn_idx]["from"]
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")
        chosen.append_message(from_str, data_point["chosen"][turn_idx]["content"])

    for turn_idx in range(len(data_point["rejected"])):
        from_str = data_point["rejected"][turn_idx]["from"]
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")
        rejected.append_message(from_str, data_point["rejected"][turn_idx]["content"])

    (
        chosen_input_ids,
        chosen_loss_mask,
        chosen_label_decode,
        rejected_input_ids,
        rejected_loss_mask,
        rejected_label_decode,
    ) = (None, None, None, None, None, None)
    if (
        len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
        <= max_length - 1
        and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
        <= max_length - 1
    ):
        chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
        (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
            chosen_data_packed["input_ids"],
            chosen_data_packed["loss_mask"],
            chosen_data_packed["label_decode"],
        )

        rejected_data_packed = apply_rlhf_data_format(
            rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
        )
        (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
            rejected_data_packed["input_ids"],
            rejected_data_packed["loss_mask"],
            rejected_data_packed["label_decode"],
        )

        # Check if either loss mask is all zeros (no loss); this may happen when the tokenized length is too long.
        if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(
            rejected_loss_mask
        ):
            return dict(
                chosen_input_ids=None,
                chosen_loss_mask=None,
                chosen_label_decode=None,
                rejected_input_ids=None,
                rejected_loss_mask=None,
                rejected_label_decode=None,
            )

        return {
            "chosen_input_ids": chosen_input_ids,
            "chosen_loss_mask": chosen_loss_mask,
            "chosen_label_decode": chosen_label_decode,
            "rejected_input_ids": rejected_input_ids,
            "rejected_loss_mask": rejected_loss_mask,
            "rejected_label_decode": rejected_label_decode,
        }
    else:
        return dict(
            chosen_input_ids=None,
            chosen_loss_mask=None,
            chosen_label_decode=None,
            rejected_input_ids=None,
            rejected_loss_mask=None,
            rejected_label_decode=None,
        )
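
# A minimal usage sketch for preference data (illustrative names, not part of
# this module): map over a raw preference dataset and filter out skipped rows,
# which are returned with all-None fields.
#
#   from functools import partial
#   dpo_dataset = raw_dataset.map(
#       partial(
#           tokenize_rlhf,
#           tokenizer=tokenizer,
#           conversation_template=conversation_template,
#           max_length=4096,
#       ),
#       remove_columns=raw_dataset.column_names,
#   ).filter(lambda x: x["chosen_input_ids"] is not None)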