fix style

pull/5922/head
YeAnbang 2024-07-26 10:07:15 +00:00
parent 8a3ff4f315
commit de1bf08ed0
4 changed files with 9 additions and 19 deletions


@@ -7,10 +7,10 @@ from .loader import (
     StatefulDistributedSampler,
     load_tokenized_dataset,
 )
-from .tokenization_utils import supervised_tokenize_sft, tokenize_kto, tokenize_prompt_dataset, tokenize_rlhf
+from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft

 __all__ = [
-    "tokenize_prompt_dataset",
+    "tokenize_prompt",
     "DataCollatorForPromptDataset",
     "is_rank_0",
     "DataCollatorForPreferenceDataset",
@@ -18,8 +18,7 @@ __all__ = [
     "DataCollatorForKTODataset",
     "StatefulDistributedSampler",
     "load_tokenized_dataset",
-    "supervised_tokenize_pretrain",
-    "supervised_tokenize_sft",
+    "tokenize_sft",
     "tokenize_rlhf",
     "tokenize_kto",
     "setup_conversation_template",


@@ -23,11 +23,10 @@ IGNORE_INDEX = -100
 DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]


-def supervised_tokenize_sft(
+def tokenize_sft(
     data_point: Dict[str, str],
     tokenizer: PreTrainedTokenizer,
     conversation_template: Conversation = None,
-    ignore_index: int = None,
     max_length: int = 4096,
 ) -> Dict[str, Union[int, str, List[int]]]:
     """
@@ -127,11 +126,10 @@ def supervised_tokenize_sft(
     )


-def tokenize_prompt_dataset(
+def tokenize_prompt(
     data_point: Dict[str, str],
     tokenizer: PreTrainedTokenizer,
     conversation_template: Conversation = None,
-    ignore_index: int = None,
     max_length: int = 4096,
 ) -> Dict[str, Union[int, str, List[int]]]:
     """
@@ -215,7 +213,6 @@ def tokenize_rlhf(
     data_point: Dict[str, str],
     tokenizer: PreTrainedTokenizer,
     conversation_template: Conversation = None,
-    ignore_index: int = None,
     max_length: int = 4096,
 ) -> Dict[str, Union[int, str, List[int]]]:
     """


@@ -40,13 +40,7 @@ import random
 import time
 from multiprocessing import cpu_count

-from coati.dataset import (
-    setup_conversation_template,
-    supervised_tokenize_sft,
-    tokenize_kto,
-    tokenize_prompt_dataset,
-    tokenize_rlhf,
-)
+from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft

 from datasets import dataset_dict, load_dataset
 from transformers import AutoTokenizer
@@ -205,9 +199,9 @@ def main():
     )

     if args.type == "sft":
-        preparation_function = supervised_tokenize_sft
+        preparation_function = tokenize_sft
     elif args.type == "prompt":
-        preparation_function = tokenize_prompt_dataset
+        preparation_function = tokenize_prompt
     elif args.type == "preference":
         preparation_function = tokenize_rlhf
     elif args.type == "kto":


@@ -1,4 +1,4 @@
-SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
+SAVE_DIR=""
 rm -rf $SAVE_DIR/cache
 rm -rf $SAVE_DIR/jsonl