mirror of https://github.com/hpcaitech/ColossalAI
fix style
parent
8a3ff4f315
commit
de1bf08ed0
|
@ -7,10 +7,10 @@ from .loader import (
|
|||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
)
|
||||
from .tokenization_utils import supervised_tokenize_sft, tokenize_kto, tokenize_prompt_dataset, tokenize_rlhf
|
||||
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
|
||||
|
||||
__all__ = [
|
||||
"tokenize_prompt_dataset",
|
||||
"tokenize_prompt",
|
||||
"DataCollatorForPromptDataset",
|
||||
"is_rank_0",
|
||||
"DataCollatorForPreferenceDataset",
|
||||
|
@ -18,8 +18,7 @@ __all__ = [
|
|||
"DataCollatorForKTODataset",
|
||||
"StatefulDistributedSampler",
|
||||
"load_tokenized_dataset",
|
||||
"supervised_tokenize_pretrain",
|
||||
"supervised_tokenize_sft",
|
||||
"tokenize_sft",
|
||||
"tokenize_rlhf",
|
||||
"tokenize_kto",
|
||||
"setup_conversation_template",
|
||||
|
|
|
@ -23,11 +23,10 @@ IGNORE_INDEX = -100
|
|||
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
|
||||
|
||||
def supervised_tokenize_sft(
|
||||
def tokenize_sft(
|
||||
data_point: Dict[str, str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
conversation_template: Conversation = None,
|
||||
ignore_index: int = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
|
@ -127,11 +126,10 @@ def supervised_tokenize_sft(
|
|||
)
|
||||
|
||||
|
||||
def tokenize_prompt_dataset(
|
||||
def tokenize_prompt(
|
||||
data_point: Dict[str, str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
conversation_template: Conversation = None,
|
||||
ignore_index: int = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
|
@ -215,7 +213,6 @@ def tokenize_rlhf(
|
|||
data_point: Dict[str, str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
conversation_template: Conversation = None,
|
||||
ignore_index: int = None,
|
||||
max_length: int = 4096,
|
||||
) -> Dict[str, Union[int, str, List[int]]]:
|
||||
"""
|
||||
|
|
|
@ -40,13 +40,7 @@ import random
|
|||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from coati.dataset import (
|
||||
setup_conversation_template,
|
||||
supervised_tokenize_sft,
|
||||
tokenize_kto,
|
||||
tokenize_prompt_dataset,
|
||||
tokenize_rlhf,
|
||||
)
|
||||
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
|
||||
from datasets import dataset_dict, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
@ -205,9 +199,9 @@ def main():
|
|||
)
|
||||
|
||||
if args.type == "sft":
|
||||
preparation_function = supervised_tokenize_sft
|
||||
preparation_function = tokenize_sft
|
||||
elif args.type == "prompt":
|
||||
preparation_function = tokenize_prompt_dataset
|
||||
preparation_function = tokenize_prompt
|
||||
elif args.type == "preference":
|
||||
preparation_function = tokenize_rlhf
|
||||
elif args.type == "kto":
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
|
||||
SAVE_DIR=""
|
||||
|
||||
rm -rf $SAVE_DIR/cache
|
||||
rm -rf $SAVE_DIR/jsonl
|
||||
|
|
Loading…
Reference in New Issue