Merge branch 'feat/prm' of github.com:TongLi3701/ColossalAI into feat/prm

pull/6119/head
Tong Li 2024-11-14 08:31:43 +00:00
commit 9ff9dc3d4a
3 changed files with 10 additions and 4 deletions

View File

@@ -7,7 +7,7 @@ from .loader import (
StatefulDistributedSampler, StatefulDistributedSampler,
load_tokenized_dataset, load_tokenized_dataset,
) )
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward from .tokenization_utils import tokenize_kto, tokenize_process_reward, tokenize_prompt, tokenize_rlhf, tokenize_sft
__all__ = [ __all__ = [
"tokenize_prompt", "tokenize_prompt",

View File

@@ -3,7 +3,6 @@ import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List from typing import Any, Dict, List
import torch.distributed as dist
from transformers import AutoTokenizer, PreTrainedTokenizer from transformers import AutoTokenizer, PreTrainedTokenizer
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger

View File

@@ -12,7 +12,14 @@ import random
import time import time
from multiprocessing import cpu_count from multiprocessing import cpu_count
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward from coati.dataset import (
setup_conversation_template,
tokenize_kto,
tokenize_process_reward,
tokenize_prompt,
tokenize_rlhf,
tokenize_sft,
)
from datasets import dataset_dict, load_dataset from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
@@ -28,7 +35,7 @@ def main():
type=str, type=str,
required=True, required=True,
default=None, default=None,
choices=["sft", "prompt", "preference", "kto", 'prm'], choices=["sft", "prompt", "preference", "kto", "prm"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'", help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
) )
parser.add_argument( parser.add_argument(