support llama3

llama3
Tong Li 2024-04-20 22:23:34 +08:00
parent f6447fb459
commit 26b94a382b
3 changed files with 34 additions and 15 deletions

View File

@@ -83,7 +83,7 @@ class Conversation:
         }


-conv = Conversation(
+LLaMA2_Conv = Conversation(
     system="A chat between a curious human and an artificial intelligence assistant. "
     "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
     roles=("Human", "Assistant"),
@@ -93,4 +93,14 @@ conv = Conversation(
     seps=["<s>", "</s>"],
 )

-default_conversation = conv
+LLaMA3_Conv = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+    roles=("Human", "Assistant"),
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
+    seps=["<|begin_of_text|>", "<|end_of_text|>"],
+)
+
+default_conversation = LLaMA3_Conv
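Note (illustration, not part of the patch): a minimal sketch of how a template with the LLaMA-3 separators could render one exchange. render_prompt is a hypothetical helper; the exact layout produced by the repository's SeparatorStyle.ADD_BOS_EOS_TOKEN may differ.

# Hypothetical helper, for illustration only: wrap each completed turn in the
# template's BOS/EOS separators and leave the final assistant slot open.
def render_prompt(system, messages, seps):
    bos, eos = seps
    text = bos + system
    for role, content in messages:
        if content is None:
            text += f"{role}: "  # generation slot
        else:
            text += f"{role}: {content}{eos}{bos}"
    return text

prompt = render_prompt(
    system="A chat between a curious human and an artificial intelligence assistant.\n\n",
    messages=[("Human", "What changed in LLaMA-3?"), ("Assistant", None)],
    seps=["<|begin_of_text|>", "<|end_of_text|>"],
)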

View File

@@ -14,6 +14,7 @@ from datasets import dataset_dict
 from torch.utils.data import ConcatDataset, Dataset, IterableDataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers import AutoTokenizer

 from colossalai.logging import get_dist_logger
@@ -71,7 +72,7 @@ def supervised_tokenize_pretrain(
 def supervised_tokenize_sft(
     data_point: Dict[str, str],
-    tokenizer: LlamaTokenizer,
+    tokenizer: AutoTokenizer,
     conversation_template: Conversation = default_conversation,
     ignore_index: int = None,
     max_length: int = 4096,
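Note (not part of the patch): AutoTokenizer is a factory that inspects the checkpoint and returns the matching tokenizer class, which is what lets one code path serve both model families; PreTrainedTokenizer would be the more precise annotation for the parameter, since AutoTokenizer itself is never instantiated. A small sketch of the dispatch, using the usual gated Hugging Face Hub checkpoints only as placeholder paths:

from transformers import AutoTokenizer

# Placeholder model IDs (gated on the Hub); shown only to illustrate dispatch.
tok2 = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")    # SentencePiece-based LlamaTokenizerFast
tok3 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # tiktoken-style PreTrainedTokenizerFast
print(type(tok2).__name__, type(tok3).__name__)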

View File

@@ -13,7 +13,8 @@ from multiprocessing import cpu_count
 from colossal_llama2.dataset.conversation import default_conversation
 from colossal_llama2.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
 from datasets import dataset_dict, load_dataset
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from transformers import AutoTokenizer
+from transformers import AddedToken

 from colossalai.logging import get_dist_logger
@@ -47,20 +48,18 @@ def main():
     )
     parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
     parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+    parser.add_argument("--llama_version", type=int, default=3, help="LLaMA version")
     args = parser.parse_args()

     if args.num_spliced_dataset_bins >= 100000:
         raise ValueError("Too many spliced divisions, must be smaller than 100000")
-    assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}"
-    assert not os.path.exists(
-        args.data_jsonl_output_dir
-    ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}"
-    assert not os.path.exists(
-        args.data_arrow_output_dir
-    ), f"Find existed arrow data output dir {args.data_arrow_output_dir}"
-    os.makedirs(args.data_jsonl_output_dir)
-    os.makedirs(args.data_arrow_output_dir)
+    if not os.path.exists(args.data_cache_dir):
+        os.makedirs(args.data_cache_dir)
+    if not os.path.exists(args.data_jsonl_output_dir):
+        os.makedirs(args.data_jsonl_output_dir)
+    if not os.path.exists(args.data_arrow_output_dir):
+        os.makedirs(args.data_arrow_output_dir)

     # Prepare to all input datasets
     input_data_paths = []
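Aside (not part of the patch): the three existence checks make the script re-runnable against pre-existing output directories; the same effect is available through os.makedirs with exist_ok=True, a general Python idiom:

import os

# Equivalent to the three if-not-exists blocks above; directory names are placeholders.
for d in ("data_cache", "data_jsonl", "data_arrow"):
    os.makedirs(d, exist_ok=True)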
@@ -83,11 +82,20 @@ def main():
         train_splits.append(f"train[{start}%:{end}%]")

     # Prepare to the tokenizer.
-    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
     # Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
-    tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
+    if args.llama_version == 2:
+        tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
     if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.unk_token
+        if tokenizer.unk_token is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.unk_token = tokenizer.eos_token

     list_dataset = load_dataset(
         path="json",
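Background (not part of the patch): the nested fallback matters because the LLaMA-3 tokenizer ships with neither a pad token nor an unk token, so the old unconditional tokenizer.pad_token = tokenizer.unk_token would leave pad_token as None and later padding calls would fail; LLaMA-2 tokenizers still expose <unk> and keep the previous behaviour. A minimal sketch of the same fallback as a standalone helper:

from transformers import PreTrainedTokenizerBase

def ensure_pad_token(tokenizer: PreTrainedTokenizerBase) -> None:
    # Mirrors the logic above: prefer <unk> as padding (LLaMA-2); otherwise
    # reuse the EOS token for both pad and unk (LLaMA-3 defines neither).
    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.unk_token = tokenizer.eos_token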