mirror of https://github.com/hpcaitech/ColossalAI
fix style
parent
8a3ff4f315
commit
de1bf08ed0
|
@ -7,10 +7,10 @@ from .loader import (
|
||||||
StatefulDistributedSampler,
|
StatefulDistributedSampler,
|
||||||
load_tokenized_dataset,
|
load_tokenized_dataset,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import supervised_tokenize_sft, tokenize_kto, tokenize_prompt_dataset, tokenize_rlhf
|
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"tokenize_prompt_dataset",
|
"tokenize_prompt",
|
||||||
"DataCollatorForPromptDataset",
|
"DataCollatorForPromptDataset",
|
||||||
"is_rank_0",
|
"is_rank_0",
|
||||||
"DataCollatorForPreferenceDataset",
|
"DataCollatorForPreferenceDataset",
|
||||||
|
@ -18,8 +18,7 @@ __all__ = [
|
||||||
"DataCollatorForKTODataset",
|
"DataCollatorForKTODataset",
|
||||||
"StatefulDistributedSampler",
|
"StatefulDistributedSampler",
|
||||||
"load_tokenized_dataset",
|
"load_tokenized_dataset",
|
||||||
"supervised_tokenize_pretrain",
|
"tokenize_sft",
|
||||||
"supervised_tokenize_sft",
|
|
||||||
"tokenize_rlhf",
|
"tokenize_rlhf",
|
||||||
"tokenize_kto",
|
"tokenize_kto",
|
||||||
"setup_conversation_template",
|
"setup_conversation_template",
|
||||||
|
|
|
@ -23,11 +23,10 @@ IGNORE_INDEX = -100
|
||||||
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||||
|
|
||||||
|
|
||||||
def supervised_tokenize_sft(
|
def tokenize_sft(
|
||||||
data_point: Dict[str, str],
|
data_point: Dict[str, str],
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
conversation_template: Conversation = None,
|
conversation_template: Conversation = None,
|
||||||
ignore_index: int = None,
|
|
||||||
max_length: int = 4096,
|
max_length: int = 4096,
|
||||||
) -> Dict[str, Union[int, str, List[int]]]:
|
) -> Dict[str, Union[int, str, List[int]]]:
|
||||||
"""
|
"""
|
||||||
|
@ -127,11 +126,10 @@ def supervised_tokenize_sft(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def tokenize_prompt_dataset(
|
def tokenize_prompt(
|
||||||
data_point: Dict[str, str],
|
data_point: Dict[str, str],
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
conversation_template: Conversation = None,
|
conversation_template: Conversation = None,
|
||||||
ignore_index: int = None,
|
|
||||||
max_length: int = 4096,
|
max_length: int = 4096,
|
||||||
) -> Dict[str, Union[int, str, List[int]]]:
|
) -> Dict[str, Union[int, str, List[int]]]:
|
||||||
"""
|
"""
|
||||||
|
@ -215,7 +213,6 @@ def tokenize_rlhf(
|
||||||
data_point: Dict[str, str],
|
data_point: Dict[str, str],
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
conversation_template: Conversation = None,
|
conversation_template: Conversation = None,
|
||||||
ignore_index: int = None,
|
|
||||||
max_length: int = 4096,
|
max_length: int = 4096,
|
||||||
) -> Dict[str, Union[int, str, List[int]]]:
|
) -> Dict[str, Union[int, str, List[int]]]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -40,13 +40,7 @@ import random
|
||||||
import time
|
import time
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
from coati.dataset import (
|
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
|
||||||
setup_conversation_template,
|
|
||||||
supervised_tokenize_sft,
|
|
||||||
tokenize_kto,
|
|
||||||
tokenize_prompt_dataset,
|
|
||||||
tokenize_rlhf,
|
|
||||||
)
|
|
||||||
from datasets import dataset_dict, load_dataset
|
from datasets import dataset_dict, load_dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
@ -205,9 +199,9 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.type == "sft":
|
if args.type == "sft":
|
||||||
preparation_function = supervised_tokenize_sft
|
preparation_function = tokenize_sft
|
||||||
elif args.type == "prompt":
|
elif args.type == "prompt":
|
||||||
preparation_function = tokenize_prompt_dataset
|
preparation_function = tokenize_prompt
|
||||||
elif args.type == "preference":
|
elif args.type == "preference":
|
||||||
preparation_function = tokenize_rlhf
|
preparation_function = tokenize_rlhf
|
||||||
elif args.type == "kto":
|
elif args.type == "kto":
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
|
SAVE_DIR=""
|
||||||
|
|
||||||
rm -rf $SAVE_DIR/cache
|
rm -rf $SAVE_DIR/cache
|
||||||
rm -rf $SAVE_DIR/jsonl
|
rm -rf $SAVE_DIR/jsonl
|
||||||
|
|
Loading…
Reference in New Issue