mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 794e0d4f4a
commit 0bfb0d32a8
@@ -7,7 +7,7 @@ from .loader import (
     StatefulDistributedSampler,
     load_tokenized_dataset,
 )
-from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward
+from .tokenization_utils import tokenize_kto, tokenize_process_reward, tokenize_prompt, tokenize_rlhf, tokenize_sft
 
 __all__ = [
     "tokenize_prompt",
@@ -23,5 +23,5 @@ __all__ = [
     "tokenize_kto",
     "setup_conversation_template",
     "Conversation",
-    "tokenize_process_reward"
+    "tokenize_process_reward",
 ]
@@ -3,7 +3,6 @@ import os
 from dataclasses import dataclass, field
 from typing import Any, Dict, List
 
-import torch.distributed as dist
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
 from colossalai.logging import get_dist_logger
@@ -12,7 +12,14 @@ import random
 import time
 from multiprocessing import cpu_count
 
-from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward
+from coati.dataset import (
+    setup_conversation_template,
+    tokenize_kto,
+    tokenize_process_reward,
+    tokenize_prompt,
+    tokenize_rlhf,
+    tokenize_sft,
+)
 from datasets import dataset_dict, load_dataset
 from transformers import AutoTokenizer
 
@@ -28,7 +35,7 @@ def main():
         type=str,
         required=True,
         default=None,
-        choices=["sft", "prompt", "preference", "kto", 'prm'],
+        choices=["sft", "prompt", "preference", "kto", "prm"],
         help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
     )
     parser.add_argument(
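
The --type choices normalized in the last hunk line up with the tokenization functions imported at the top of the same script. As a minimal sketch of how that flag could be dispatched to those functions: the TOKENIZE_BY_TYPE mapping and get_tokenize_fn helper below are hypothetical, not code from this commit, and in particular pairing "preference" with tokenize_rlhf is an assumption.

# Hypothetical dispatch from the --type flag to the tokenization functions
# imported in the diff above; this mapping is illustrative only.
from coati.dataset import (
    tokenize_kto,
    tokenize_process_reward,
    tokenize_prompt,
    tokenize_rlhf,
    tokenize_sft,
)

TOKENIZE_BY_TYPE = {
    "sft": tokenize_sft,
    "prompt": tokenize_prompt,
    "preference": tokenize_rlhf,  # assumption: preference data uses the RLHF tokenizer
    "kto": tokenize_kto,
    "prm": tokenize_process_reward,
}


def get_tokenize_fn(dataset_type: str):
    # Return the tokenization function for one of the argparse choices,
    # raising a clear error for anything outside the accepted set.
    try:
        return TOKENIZE_BY_TYPE[dataset_type]
    except KeyError:
        raise ValueError(f"unknown dataset type: {dataset_type!r}") from None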