mirror of https://github.com/hpcaitech/ColossalAI

support llama3

parent f6447fb459
commit 26b94a382b
@@ -83,7 +83,7 @@ class Conversation:
 }
 
 
-conv = Conversation(
+LLaMA2_Conv = Conversation(
     system="A chat between a curious human and an artificial intelligence assistant. "
     "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
     roles=("Human", "Assistant"),
@@ -93,4 +93,14 @@ conv = Conversation(
     seps=["<s>", "</s>"],
 )
 
-default_conversation = conv
+LLaMA3_Conv = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+    roles=("Human", "Assistant"),
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
+    seps=["<|begin_of_text|>", "<|end_of_text|>"],
+)
+
+default_conversation = LLaMA3_Conv
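In the conversation-template module, the new LLaMA3_Conv reuses the LLaMA-2 prompt wording and only swaps the separators for LLaMA-3's <|begin_of_text|> / <|end_of_text|> special tokens, keeping the ADD_BOS_EOS_TOKEN separator style. As a rough, standalone illustration of what such a template does when rendered (this is not the repository's Conversation class, whose get_prompt logic is outside this diff):

# Illustrative only: a toy ADD_BOS_EOS_TOKEN-style template, not the repo's class.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyConversation:
    system: str
    roles: Tuple[str, str]
    seps: List[str]  # [bos, eos]
    messages: List[Tuple[str, str]] = field(default_factory=list)

    def append_message(self, role: str, message: str) -> None:
        self.messages.append((role, message))

    def get_prompt(self) -> str:
        # BOS once at the start, EOS after each completed assistant reply.
        prompt = self.seps[0] + self.system
        for role, message in self.messages:
            if message:
                prompt += f"{role}: {message}"
                prompt += self.seps[1] if role == self.roles[1] else "\n"
            else:
                prompt += f"{role}:"
        return prompt


conv = ToyConversation(
    system="A chat between a curious human and an artificial intelligence assistant.\n\n",
    roles=("Human", "Assistant"),
    seps=["<|begin_of_text|>", "<|end_of_text|>"],  # LLaMA-3 special tokens
)
conv.append_message("Human", "What is ColossalAI?")
conv.append_message("Assistant", "")
print(conv.get_prompt())

Because the separators are real special tokens, the data-preparation script later switches off the tokenizer's automatic BOS/EOS insertion (see the add_bos_token / add_eos_token lines further down).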
@@ -14,6 +14,7 @@ from datasets import dataset_dict
 from torch.utils.data import ConcatDataset, Dataset, IterableDataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers import AutoTokenizer
 
 from colossalai.logging import get_dist_logger
 
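In the dataset tokenization helpers, adding AutoTokenizer alongside the existing LlamaTokenizer import matters because LLaMA-3 no longer ships a SentencePiece model: its tokenizer is a BPE tokenizer loaded as a plain PreTrainedTokenizerFast, which the SentencePiece-based LlamaTokenizer cannot read. AutoTokenizer resolves the right class from the checkpoint's tokenizer files. A minimal sketch of the difference (the checkpoint paths are placeholders, not paths from this repository):

# Illustrative only: AutoTokenizer picks the concrete tokenizer class per checkpoint.
from transformers import AutoTokenizer

for model_dir in ("path/to/llama-2-checkpoint", "path/to/llama-3-checkpoint"):
    tok = AutoTokenizer.from_pretrained(model_dir)
    # LLaMA-2 resolves to LlamaTokenizer/LlamaTokenizerFast;
    # LLaMA-3 resolves to PreTrainedTokenizerFast.
    print(type(tok).__name__, tok.bos_token, tok.eos_token)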
@@ -71,7 +72,7 @@ def supervised_tokenize_pretrain(
 
 def supervised_tokenize_sft(
     data_point: Dict[str, str],
-    tokenizer: LlamaTokenizer,
+    tokenizer: AutoTokenizer,
     conversation_template: Conversation = default_conversation,
     ignore_index: int = None,
     max_length: int = 4096,
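The signature change is purely a type-hint update; at runtime supervised_tokenize_sft receives whatever AutoTokenizer.from_pretrained returned. The ignore_index default of None presumably resolves elsewhere in the module to the usual -100 sentinel (an assumption, since that constant is outside this diff), which is the value PyTorch's CrossEntropyLoss skips so that prompt tokens do not contribute to the SFT loss. A small self-contained illustration of that convention:

# Illustrative only: the standard -100 label-masking convention, not repo code.
import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(5, vocab_size)           # 5 token positions
labels = torch.tensor([-100, -100, 3, 1, 7])  # first two positions are prompt tokens

loss = F.cross_entropy(logits, labels, ignore_index=-100)
# Equivalent to averaging the loss over only the last three (response) positions.
manual = F.cross_entropy(logits[2:], labels[2:])
assert torch.allclose(loss, manual)
print(loss.item())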
@@ -13,7 +13,8 @@ from multiprocessing import cpu_count
 from colossal_llama2.dataset.conversation import default_conversation
 from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
 from datasets import dataset_dict, load_dataset
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from transformers import AutoTokenizer
+from transformers import AddedToken
 
 from colossalai.logging import get_dist_logger
 
@@ -47,20 +48,18 @@ def main():
     )
     parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
     parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+    parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version")
     args = parser.parse_args()
 
     if args.num_spliced_dataset_bins >= 100000:
         raise ValueError("Too many spliced divisions, must be smaller than 100000")
 
-    assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
-    assert not os.path.exists(
-        args.data_jsonl_output_dir
-    ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
-    assert not os.path.exists(
-        args.data_arrow_output_dir
-    ), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
-    os.makedirs(args.data_jsonl_output_dir)
-    os.makedirs(args.data_arrow_output_dir)
+    if not os.path.exists(args.data_cache_dir):
+        os.makedirs(args.data_cache_dir)
+    if not os.path.exists(args.data_jsonl_output_dir):
+        os.makedirs(args.data_jsonl_output_dir)
+    if not os.path.exists(args.data_arrow_output_dir):
+        os.makedirs(args.data_arrow_output_dir)
 
     # Prepare to all input datasets
     input_data_paths = []
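In the data-preparation script, swapping the hard asserts for existence checks makes the script re-runnable against an already-created output tree instead of aborting on leftovers from a previous run. The same behaviour can be written more compactly with os.makedirs(..., exist_ok=True); the helper below is only a hypothetical equivalent, not code from the repository:

# Illustrative only: an idempotent way to prepare output directories (hypothetical helper).
import os


def ensure_output_dirs(*paths: str) -> None:
    for path in paths:
        os.makedirs(path, exist_ok=True)  # no-op if the directory already exists


# e.g. ensure_output_dirs(args.data_cache_dir, args.data_jsonl_output_dir, args.data_arrow_output_dir)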
@@ -83,11 +82,20 @@ def main():
         train_splits.append(f"train[{start}%:{end}%]")
 
     # Prepare to the tokenizer.
-    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+
+    # Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
+    if args.llama_version == 2:
+        tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
+
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
     if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.unk_token
+        if tokenizer.unk_token is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.unk_token = tokenizer.eos_token
 
     list_dataset = load_dataset(
         path="json",
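Two version-specific fixups land here. For LLaMA-2, registering </s> as an AddedToken works around the tokenization issue linked in the comment, where the token could otherwise be split during encoding. For LLaMA-3, the pad-token fallback is widened because its tokenizer defines neither a pad nor an unk token, so eos is the remaining reasonable stand-in. A consolidated sketch of the same setup as a standalone helper (hypothetical wrapper mirroring the diff, not copied from the repository):

# Illustrative only: a hypothetical helper mirroring the tokenizer setup above.
from transformers import AddedToken, AutoTokenizer, PreTrainedTokenizerBase


def prepare_tokenizer(tokenizer_dir: str, llama_version: int = 3) -> PreTrainedTokenizerBase:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    if llama_version == 2:
        # Work around </s> being split into several tokens during encoding.
        # https://github.com/huggingface/transformers/issues/23833
        tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)

    # The preparation pipeline inserts BOS/EOS itself via the conversation template.
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token   # LLaMA-2 path
        else:
            tokenizer.pad_token = tokenizer.eos_token   # LLaMA-3 defines no unk token
            tokenizer.unk_token = tokenizer.eos_token

    return tokenizer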