From 7a7e86987dd3f8a43e338ce9d6f6f4c019cc5073 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 27 May 2024 05:55:57 +0000 Subject: [PATCH] upgrade colossal-chat support tp_group>1, add sp for sft --- .../coati/dataset/conversation.py | 5 +- .../ColossalChat/coati/dataset/loader.py | 25 +- .../coati/dataset/tokenization_utils.py | 44 +- .../ColossalChat/coati/dataset/utils.py | 11 +- .../01-ai_Yi-1.5-9B-Chat.json | 8 + .../{Vicuna.json => 01-ai_Yi-34B.json} | 5 +- .../{chatGLM2.json => THUDM_chatglm2-6b.json} | 5 +- .../THUDM_chatglm3-6b.json | 8 + .../config/conversation_template/Yi.json | 7 - ...n => baichuan-inc_Baichuan2-13B-Chat.json} | 7 +- .../colossal-llama2.json | 5 +- .../deepseek-ai_DeepSeek-V2-Lite.json | 8 + .../config/conversation_template/llama2.json | 5 +- .../microsoft_phi-2.json | 8 + ...mistralai_Mixtral-8x7B-Instruct-v0.1.json} | 5 +- .../config/conversation_template/zephyr.json | 7 - applications/ColossalChat/examples/README.md | 13 +- .../prepare_dataset.py | 2 +- .../prepare_sft_dataset.sh | 17 +- .../examples/training_scripts/hostfile | 2 +- .../examples/training_scripts/train_sft.py | 90 +- .../examples/training_scripts/train_sft.sh | 48 +- applications/ColossalChat/requirements.txt | 5 +- .../ColossalChat/tests/test_chat_template.py | 38 + .../chat_template/01-ai_Yi-1.5-9B-Chat.json | 585 +++++ .../test_data/chat_template/01-ai_Yi-34B.json | 607 +++++ .../chat_template/Qwen_Qwen-7B-Chat.json | 603 +++++ .../chat_template/THUDM_chatglm2-6b.json | 715 ++++++ .../chat_template/THUDM_chatglm3-6b.json | 585 +++++ .../baichuan-inc_Baichuan2-13B-Chat.json | 697 ++++++ .../deepseek-ai_DeepSeek-V2-Lite.json | 581 +++++ .../chat_template/microsoft_phi-2.json | 2009 +++++++++++++++++ .../mistralai_Mixtral-8x7B-Instruct-v0.1.json | 919 ++++++++ 33 files changed, 7574 insertions(+), 105 deletions(-) create mode 100644 applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json rename 
applications/ColossalChat/config/conversation_template/{Vicuna.json => 01-ai_Yi-34B.json} (97%) rename applications/ColossalChat/config/conversation_template/{chatGLM2.json => THUDM_chatglm2-6b.json} (90%) create mode 100644 applications/ColossalChat/config/conversation_template/THUDM_chatglm3-6b.json delete mode 100644 applications/ColossalChat/config/conversation_template/Yi.json rename applications/ColossalChat/config/conversation_template/{Qwen.json => baichuan-inc_Baichuan2-13B-Chat.json} (88%) create mode 100644 applications/ColossalChat/config/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json create mode 100644 applications/ColossalChat/config/conversation_template/microsoft_phi-2.json rename applications/ColossalChat/config/conversation_template/{mistral.json => mistralai_Mixtral-8x7B-Instruct-v0.1.json} (93%) delete mode 100644 applications/ColossalChat/config/conversation_template/zephyr.json create mode 100644 applications/ColossalChat/tests/test_chat_template.py create mode 100644 applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-1.5-9B-Chat.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-34B.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/Qwen_Qwen-7B-Chat.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm2-6b.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm3-6b.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/baichuan-inc_Baichuan2-13B-Chat.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/deepseek-ai_DeepSeek-V2-Lite.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/microsoft_phi-2.json create mode 100644 applications/ColossalChat/tests/test_data/chat_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json diff --git a/applications/ColossalChat/coati/dataset/conversation.py 
b/applications/ColossalChat/coati/dataset/conversation.py index 15a33be93..ec46ea429 100755 --- a/applications/ColossalChat/coati/dataset/conversation.py +++ b/applications/ColossalChat/coati/dataset/conversation.py @@ -17,6 +17,7 @@ class Conversation: system_message: str chat_template: str stop_ids: List[int] + end_of_assistant: str @classmethod def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict): @@ -24,7 +25,7 @@ class Conversation: Setup the conversation template from config """ tokenizer.chat_template = config["chat_template"] - conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"]) + conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]) conv.clear() return conv @@ -109,6 +110,8 @@ def setup_conversation_template( """ if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]): # Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually + if "end_of_assistant" not in chat_template_config: + raise ValueError("Please set the end of assistant token.") if "system_message" not in chat_template_config: logger.warning("No system message is provided, will not use system message.") if "chat_template" not in chat_template_config: diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 93cc1dab8..26f210eb2 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -248,9 +248,9 @@ class StatefulDistributedSampler(DistributedSampler): shuffle: bool = True, seed: int = 0, drop_last: bool = False, - use_tp: Optional[bool] = False, + tp_size: int = 1, ) -> None: - if not use_tp: + if not tp_size>1: super().__init__( dataset=dataset, num_replicas=num_replicas, @@ -261,14 +261,16 @@ class StatefulDistributedSampler(DistributedSampler): ) 
else: # adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62 - # TODO: support tp_group>1. will fix it later - num_replicas = 1 if rank is None: rank = dist.get_rank() + world_size = dist.get_world_size() + dp_size = world_size // tp_size # data parallel size + dp_rank = int(rank / tp_size) # data parallel rank if rank < 0: raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]") self.dataset = dataset self.num_replicas = num_replicas + self.dp_rank = dp_rank self.rank = rank self.epoch = 0 self.drop_last = drop_last @@ -287,10 +289,10 @@ class StatefulDistributedSampler(DistributedSampler): self.shuffle = shuffle self.seed = seed self.start_index = 0 - self.use_tp = use_tp + self.tp_size = tp_size def __iter__(self) -> Iterator: - if self.use_tp: + if self.tp_size > 1: # TODO Add support for tp_group not equal to 1 pass # adpated from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96 @@ -316,10 +318,9 @@ class StatefulDistributedSampler(DistributedSampler): # subsample indices = indices[ - : self.total_size : self.num_replicas + self.dp_rank: self.dp_rank + self.total_size : self.num_replicas ] # num_replicas=tp_group=1, we only support tp_group==1 for now assert len(indices) == self.num_samples - return iter(indices) else: @@ -345,7 +346,7 @@ def setup_distributed_dataloader( num_workers: int = 0, collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, process_group: Optional[ProcessGroup] = None, - use_tp: Optional[bool] = False, + tp_size: Optional[int] = 1, **kwargs, ) -> DataLoader: """ @@ -353,14 +354,16 @@ def setup_distributed_dataloader( """ _kwargs = kwargs.copy() process_group = process_group or _get_default_group() + # world_size = tp_size * pp_size + assert process_group.size()%tp_size == 0, f"process_group.size()={process_group.size()} must be 
divisible by tp_size={tp_size}" sampler = StatefulDistributedSampler( dataset=dataset, - num_replicas=process_group.size() if not use_tp else 1, + num_replicas=int(process_group.size()/tp_size), rank=process_group.rank(), shuffle=shuffle, seed=seed, drop_last=drop_last, - use_tp=use_tp, + tp_size=tp_size, ) # Deterministic dataloader diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 7606bc2a9..78b14ce4b 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -55,6 +55,8 @@ def supervised_tokenize_sft( for mess in messages: from_str = mess["from"] + if from_str is None: + print(mess) if from_str.lower() == "human": from_str = "user" elif from_str.lower() == "assistant": @@ -95,17 +97,26 @@ def supervised_tokenize_sft( target_turn = turns[target_turn_index - 1] prompt = template.get_prompt(2 * target_turn) - chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt) + chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt, + conversation_template.end_of_assistant) tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) labels = [ignore_index] * len(tokenized) - label_decode = [] for start, end in zip(starts, ends): if end == len(tokenized): tokenized = tokenized + [tokenizer.eos_token_id] labels = labels + [ignore_index] - labels[start : end + 1] = tokenized[start : end + 1] - label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False)) + labels[start : end] = tokenized[start : end] + + # truncate the sequence at the last token that requires loss calculation + to_truncate_len = 0 + for i in range(len(tokenized) - 1, -1, -1): + if labels[i] == ignore_index: + to_truncate_len += 1 + else: + break + tokenized = tokenized[: len(tokenized) 
- to_truncate_len] + labels = labels[: len(labels) - to_truncate_len] if tokenizer.bos_token_id is not None: if tokenized[0] != tokenizer.bos_token_id: @@ -123,6 +134,20 @@ def supervised_tokenize_sft( # For some model without bos/eos may raise the following errors try: inputs_decode = tokenizer.decode(tokenized) + start = 0 + end = 0 + label_decode = [] + for i in range(len(labels)): + if labels[i] == ignore_index: + if start!=end: + label_decode.append(tokenizer.decode(labels[start+1:i], skip_special_tokens=False)) + start = i + end = i + else: + end = i + if i == len(labels) - 1: + label_decode.append(tokenizer.decode(labels[start+1:], skip_special_tokens=False)) + except TypeError as e: raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}") @@ -191,7 +216,9 @@ def tokenize_prompt_dataset( # Prepare data prompt = template.get_prompt(target_turn, add_generation_prompt=True) - tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] + chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: target_turn], prompt, + conversation_template.end_of_assistant) + tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) if tokenizer.bos_token_id is not None: if tokenized[0] != tokenizer.bos_token_id: tokenized = [tokenizer.bos_token_id] + tokenized @@ -219,7 +246,8 @@ def apply_rlhf_data_format( ): target_turn = int(len(template.messages) / 2) prompt = template.get_prompt(target_turn * 2) - chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt) + chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt, + tempalte.end_of_assistant) tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) loss_mask = [0] * len(tokenized) mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id @@ -232,8 +260,8 @@ def apply_rlhf_data_format( if end == len(tokenized): tokenized = 
tokenized + [tokenizer.eos_token_id] loss_mask = loss_mask + [1] - loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1]) - label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False)) + loss_mask[start : end] = [1] * len(loss_mask[start : end]) + label_decode.append(tokenizer.decode(tokenized[start : end], skip_special_tokens=False)) if tokenizer.bos_token_id is not None: if tokenized[0] != tokenizer.bos_token_id: tokenized = [tokenizer.bos_token_id] + tokenized diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py index ada2afef0..eaef8af1a 100755 --- a/applications/ColossalChat/coati/dataset/utils.py +++ b/applications/ColossalChat/coati/dataset/utils.py @@ -113,20 +113,23 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re return input_ids, loss_starts, loss_ends -def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str): +def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str, end_of_assistant: str): # Seperate templated prompt into chunks by human/assistant's lines, prepare data for tokenize_and_concatenate start_idx = 0 chunks = [] require_loss = [] for line in messages: + content_length = len(line["content"]) first_occur = prompt.find(line["content"], start_idx) + if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length:]: + content_length = prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur if prompt[first_occur - 1] != " ": chunks.append(prompt[start_idx:first_occur]) - chunks.append(prompt[first_occur : first_occur + len(line["content"])]) + chunks.append(prompt[first_occur : first_occur + content_length]) else: chunks.append(prompt[start_idx : first_occur - 1]) - chunks.append(prompt[first_occur - 1 : first_occur + len(line["content"])]) - start_idx = first_occur + 
len(line["content"]) + chunks.append(prompt[first_occur - 1 : first_occur + content_length]) + start_idx = first_occur + content_length if line["role"].lower() == "assistant": require_loss.append(False) require_loss.append(True) diff --git a/applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json b/applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json new file mode 100644 index 000000000..85a726766 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/01-ai_Yi-1.5-9B-Chat.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 7 + ], + "end_of_assistant": "<|im_end|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/Vicuna.json b/applications/ColossalChat/config/conversation_template/01-ai_Yi-34B.json similarity index 97% rename from applications/ColossalChat/config/conversation_template/Vicuna.json rename to applications/ColossalChat/config/conversation_template/01-ai_Yi-34B.json index 2b00b6529..614c25ca6 100644 --- a/applications/ColossalChat/config/conversation_template/Vicuna.json +++ b/applications/ColossalChat/config/conversation_template/01-ai_Yi-34B.json @@ -3,5 +3,6 @@ "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", "stop_ids": [ 2 - ] -} + ], + "end_of_assistant": "<|endoftext|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/chatGLM2.json b/applications/ColossalChat/config/conversation_template/THUDM_chatglm2-6b.json similarity index 90% rename from applications/ColossalChat/config/conversation_template/chatGLM2.json rename to applications/ColossalChat/config/conversation_template/THUDM_chatglm2-6b.json index a2638dbe7..9d8531753 100644 --- a/applications/ColossalChat/config/conversation_template/chatGLM2.json +++ b/applications/ColossalChat/config/conversation_template/THUDM_chatglm2-6b.json @@ -3,5 +3,6 @@ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", "stop_ids": [ 2 - ] -} + ], + "end_of_assistant": "<|im_end|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/THUDM_chatglm3-6b.json b/applications/ColossalChat/config/conversation_template/THUDM_chatglm3-6b.json new file mode 100644 index 000000000..d791e1ae8 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/THUDM_chatglm3-6b.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ], + "end_of_assistant": "<|user|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/Yi.json b/applications/ColossalChat/config/conversation_template/Yi.json deleted file mode 100644 index 9716413b5..000000000 --- a/applications/ColossalChat/config/conversation_template/Yi.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", - "stop_ids": [ - 2 - ] -} diff --git a/applications/ColossalChat/config/conversation_template/Qwen.json b/applications/ColossalChat/config/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json similarity index 88% rename from applications/ColossalChat/config/conversation_template/Qwen.json rename to applications/ColossalChat/config/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json index 09f706ffe..9d8531753 100644 --- a/applications/ColossalChat/config/conversation_template/Qwen.json +++ b/applications/ColossalChat/config/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json @@ -2,6 +2,7 @@ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", "stop_ids": [ - null - ] -} + 2 + ], + "end_of_assistant": "<|im_end|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/colossal-llama2.json b/applications/ColossalChat/config/conversation_template/colossal-llama2.json index cc7f1e5d7..b9c17c1e2 100644 --- a/applications/ColossalChat/config/conversation_template/colossal-llama2.json +++ b/applications/ColossalChat/config/conversation_template/colossal-llama2.json @@ -3,5 +3,6 @@ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", "stop_ids": [ 2 - ] -} + ], + "end_of_assistant": "" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json b/applications/ColossalChat/config/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json new file mode 100644 index 000000000..89a9aed85 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 100001 + ], + "end_of_assistant": "<|end▁of▁sentence|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/llama2.json b/applications/ColossalChat/config/conversation_template/llama2.json index 80558f976..5fbe8b4fc 100644 --- a/applications/ColossalChat/config/conversation_template/llama2.json +++ b/applications/ColossalChat/config/conversation_template/llama2.json @@ -3,5 +3,6 @@ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", "stop_ids": [ 2 - ] -} + ], + "end_of_assistant": "" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/microsoft_phi-2.json b/applications/ColossalChat/config/conversation_template/microsoft_phi-2.json new file mode 100644 index 000000000..60ec8b763 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/microsoft_phi-2.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 50256 + ], + "end_of_assistant": "<|im_end|>" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/mistral.json b/applications/ColossalChat/config/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json similarity index 93% rename from applications/ColossalChat/config/conversation_template/mistral.json rename to applications/ColossalChat/config/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json index b48c3a3f2..9a7df645d 100644 --- a/applications/ColossalChat/config/conversation_template/mistral.json +++ b/applications/ColossalChat/config/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json @@ -3,5 +3,6 @@ "system_message": null, "stop_ids": [ 2 - ] -} + ], + "end_of_assistant": "" +} \ No newline at end of file diff --git a/applications/ColossalChat/config/conversation_template/zephyr.json b/applications/ColossalChat/config/conversation_template/zephyr.json deleted file mode 100644 index 2ab141111..000000000 --- a/applications/ColossalChat/config/conversation_template/zephyr.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", - "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", - "stop_ids": [ - 2 - ] -} diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md index cfed3f1f3..3318e498b 100755 --- a/applications/ColossalChat/examples/README.md +++ b/applications/ColossalChat/examples/README.md @@ -338,7 +338,18 @@ In this code we provide a flexible way for users to set the conversation templat { "chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating, "system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added, - "stop_ids": (Optional), A list of string indicating the end of assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically, + "end_of_assistant": The token(s), given as a string, that denote the end of the assistant's response. For example, in the ChatGLM2 prompt format, + ``` + <|im_start|>system + system messages + + <|im_end|> + <|im_start|>user + How far is the moon? <|im_end|> + <|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>... + ``` + the end_of_assistant tokens are "<|im_end|>" + "stop_ids": (Optional), A list of integers corresponding to the `end_of_assistant` tokens that indicate the end of the assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. 
If not set, will set to tokenizer.eos_token_ids automatically } ``` On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message), diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py index 64093f88d..04e613d0c 100644 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py @@ -226,7 +226,7 @@ def main(): "max_length": args.max_length, }, keep_in_memory=False, - num_proc=min(len(dataset), cpu_count()), + num_proc= min(len(dataset), cpu_count()), ) dataset = dataset.filter( diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh index cf937db2a..9bb332cac 100755 --- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh @@ -1,13 +1,22 @@ -SAVE_DIR="" +SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf" rm -rf $SAVE_DIR/cache rm -rf $SAVE_DIR/jsonl rm -rf $SAVE_DIR/arrow +# python prepare_dataset.py --type sft \ +# --data_input_dirs /home/yeanbang/data/experiment/dataset/sft_data/test/sft-data \ +# --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/THUDM_chatglm3-6b.json \ +# --tokenizer_dir "/mnt/jfs-hdd/home/data/models/ChatGlm-6B" \ +# --data_cache_dir $SAVE_DIR/cache \ +# --data_jsonl_output_dir $SAVE_DIR/jsonl \ +# --data_arrow_output_dir $SAVE_DIR/arrow \ + + python prepare_dataset.py --type sft \ - --data_input_dirs /PATH/TO/SFT/DATASET \ - 
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ - --tokenizer_dir "" \ + --data_input_dirs /mnt/jfs-hdd/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train \ + --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \ + --tokenizer_dir "/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" \ --data_cache_dir $SAVE_DIR/cache \ --data_jsonl_output_dir $SAVE_DIR/jsonl \ --data_arrow_output_dir $SAVE_DIR/arrow \ diff --git a/applications/ColossalChat/examples/training_scripts/hostfile b/applications/ColossalChat/examples/training_scripts/hostfile index d4118dda9..c46761d14 100755 --- a/applications/ColossalChat/examples/training_scripts/hostfile +++ b/applications/ColossalChat/examples/training_scripts/hostfile @@ -1 +1 @@ -10.20.1.82 +172.27.183.199 diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index ae20f2abc..5ee9b741e 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -14,13 +14,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam +import inspect +import sys +import torch.distributed as dist +from colossalai.logging import get_dist_logger +logger = get_dist_logger() def train(args): + print(colossalai.__version__, inspect.getfile(colossalai)) + print(sys.executable) # check lora compatibility if "gemini" in args.plugin and 
args.lora_rank > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") @@ -35,6 +42,38 @@ def train(args): # ============================== # Initialize Booster # ============================== + init_ctx = nullcontext() + with init_ctx: + model = AutoModelForCausalLM.from_pretrained(args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True) + # check if the hybrid parallel plugin is compatible with the model + # try: + # from colossalai.shardformer.policies.auto_policy import get_autopolicy + # policy = get_autopolicy(model) + # if policy is not None: + # if args.plugin in ['zero2', 'zero2_cpu']: + # # if compatible, set the plugin to hybrid, which use colo-attention + # args.plugin = 'hybrid' + # args.zero_stage = 2 + # if args.plugin == 'zero2_cpu': + # args.zero_cpu_offload = True + # else: + # args.zero_cpu_offload = False + # logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}") + # except NotImplementedError: + # logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead") + # if args.use_flash_attn: + # del model + # model = AutoModelForCausalLM.from_pretrained( + # args.pretrain, + # torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + # attn_implementation="flash_attention_2", + # trust_remote_code=True + # ) + if args.lora_rank > 0: + model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) + if args.plugin == "ddp": """ Default torch ddp plugin without any acceleration, for @@ -47,7 +86,8 @@ def train(args): placement_policy="static", initial_scale=2**16, max_norm=args.grad_clip, - enable_gradient_accumulation=True, + enable_gradient_accumulation=True if args.accumulation_steps > 1 else False, + enable_flash_attention=args.use_flash_attn ) elif 
args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -55,6 +95,7 @@ def train(args): placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip, + enable_flash_attention=args.use_flash_attn ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin( @@ -71,11 +112,16 @@ def train(args): cpu_offload=True, max_norm=args.grad_clip, ) - elif args.plugin == "3d": + elif args.plugin == "hybrid": plugin = HybridParallelPlugin( tp_size=args.tp, - pp_size=1, - zero_stage=0, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_sequence_parallelism=True if args.sp > 1 else False, + cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False, parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, @@ -93,20 +139,7 @@ def train(args): # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() # ) - init_ctx = nullcontext() - with init_ctx: - if args.use_flash_attn: - model = AutoModelForCausalLM.from_pretrained( - args.pretrain, - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - use_flash_attention_2=True, - ) - coordinator.print_on_master(msg="Flash-attention enabled successfully") - else: - model = AutoModelForCausalLM.from_pretrained(args.pretrain) - if args.lora_rank > 0: - model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) - + if args.grad_checkpoint and args.lora_rank == 0: # lora layers are not supported by gradient checkpointing model.gradient_checkpointing_enable() @@ -131,6 +164,7 @@ def train(args): tokenizer.add_bos_token = False tokenizer.add_eos_token = False + tokenizer.padding_side = "right" coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}") @@ -156,8 
+190,13 @@ def train(args): shuffle=True, drop_last=True, collate_fn=data_collator, - use_tp=args.tp > 1, + tp_size=args.tp, ) + # print(len(train_dataloader)) + # for batch in train_dataloader: + # print(dist.get_rank(), tokenizer.batch_decode(batch["input_ids"])) + # break + coordinator.print_on_master( f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" ) @@ -255,7 +294,7 @@ def train(args): # save model checkpoint after fitting on only rank0 coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True) + # booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True) coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}") coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") @@ -270,13 +309,18 @@ if __name__ == "__main__": "--plugin", type=str, default="gemini", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"], + choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"], help="Choose which plugin to use", ) parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--sp", type=int, default=1) + parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2]) + parser.add_argument("--zero_cpu_offload", default=False, action="store_true") + parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"]) parser.add_argument("--pretrain", type=str, default=None) 
parser.add_argument("--tokenizer_dir", type=str, default=None) parser.add_argument("--dataset", nargs="+", default=[]) @@ -287,7 +331,7 @@ if __name__ == "__main__": parser.add_argument("--max_epochs", type=int, default=3) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--max_len", type=int, default=512) - parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument( "--lora_train_bias", diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh index d5c394377..e3bc12823 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.sh +++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh @@ -17,22 +17,22 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { # export CUDA_VISIBLE_DEVICES=4,5,6 set_n_least_used_CUDA_VISIBLE_DEVICES 4 PROJECT_NAME="sft" -PARENT_SAVE_DIR="" # Path to a folder to save checkpoints -PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs -PARENT_CONFIG_FILE="" # Path to a folder to save training config logs -PRETRAINED_MODEL_PATH="" # huggingface or local model path -PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +PARENT_SAVE_DIR="/home/yeanbang/data/experiment/output/model" # Path to a folder to save checkpoints +PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/logs/tensorboard" # Path to a folder to save logs +PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/logs/config" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local model path 
+PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local tokenizer path declare -a dataset=( - YOUR/SFT/DATA/DIR/arrow/part-00000 - YOUR/SFT/DATA/DIR/arrow/part-00001 - YOUR/SFT/DATA/DIR/arrow/part-00002 - YOUR/SFT/DATA/DIR/arrow/part-00003 - YOUR/SFT/DATA/DIR/arrow/part-00004 - YOUR/SFT/DATA/DIR/arrow/part-00005 - YOUR/SFT/DATA/DIR/arrow/part-00006 - YOUR/SFT/DATA/DIR/arrow/part-00007 - YOUR/SFT/DATA/DIR/arrow/part-00008 - YOUR/SFT/DATA/DIR/arrow/part-00009 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00000 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00001 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00002 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00003 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00004 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00005 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00006 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00007 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00008 + /home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00009 ) TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) @@ -40,6 +40,8 @@ FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" +echo $(which colossalai) +echo $(which python) # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \ --pretrain $PRETRAINED_MODEL_PATH \ @@ -50,10 +52,14 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai 
--config_file $CONFIG_FILE \ --lora_rank 0 \ --plugin zero2 \ - --batch_size 8 \ - --max_epochs 1 \ - --accumulation_steps 1 \ - --lr 2e-5 \ - --max_len 2048 \ + --tp 1 \ + --pp 1 \ + --zero_stage 2 \ + --batch_size 4 \ + --max_epochs 3 \ + --accumulation_steps 4 \ + --lr 5e-5 \ + --max_len 400 \ --grad_checkpoint \ - --use_wandb + --use_wandb \ + --use_flash_attn diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt index de5f6160e..3ad29e6e1 100755 --- a/applications/ColossalChat/requirements.txt +++ b/applications/ColossalChat/requirements.txt @@ -1,9 +1,8 @@ -transformers==4.34.1 -huggingface_hub==0.17.3 +transformers>=4.37.0 tqdm datasets loralib -colossalai>=0.3.6 +colossalai>=0.3.7 torch>=1.12.1 langchain tokenizers diff --git a/applications/ColossalChat/tests/test_chat_template.py b/applications/ColossalChat/tests/test_chat_template.py new file mode 100644 index 000000000..2395c1f3c --- /dev/null +++ b/applications/ColossalChat/tests/test_chat_template.py @@ -0,0 +1,38 @@ +from coati.dataset import setup_conversation_template +from coati.dataset.conversation import Conversation +from coati.dataset.tokenization_utils import supervised_tokenize_sft +from transformers import AutoTokenizer +import json +import os + +model_data_mapping = { + 'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json', + 'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json', + 'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json', + 'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json', + '01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json', + '01-ai/Yi-34B': '01-ai_Yi-34B.json', + 'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json', + 'microsoft/phi-2': 'microsoft_phi-2.json', + 'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json' +} +chat_template_config_path = '../config/conversation_template' + + +def test_tokenization_sft(): + for model in model_data_mapping: + 
print(f"#############{model}#############") + conversation_template_config = os.path.join(chat_template_config_path, model_data_mapping[model]) + messages = [{"from": "human", "content": "What are the three primary colors?"}, + {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."}, + {"from": "human", "content": "解释个人电脑和服务器之间的区别。"}, + {"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}] + chat_template_config = json.load(open(conversation_template_config, "r", encoding="utf8")) + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) + conversation_template = setup_conversation_template( + tokenizer, chat_template_config=chat_template_config, save_path=conversation_template_config + ) + + output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template) + with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f: + assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed" diff --git a/applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-1.5-9B-Chat.json b/applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-1.5-9B-Chat.json new file mode 100644 index 000000000..52a20f813 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-1.5-9B-Chat.json @@ -0,0 +1,585 @@ +{ + "input_ids": [ + 1, + 59603, + 9334, + 1397, + 562, + 13310, + 2756, + 597, + 663, + 15874, + 10357, + 14135, + 98, + 707, + 14135, + 3641, + 6901, + 97, + 7283, + 97, + 597, + 31081, + 8476, + 592, + 567, + 2756, + 59610, + 59575, + 3275, 
+ 98, + 144, + 144, + 6, + 3903, + 144, + 5697, + 678, + 567, + 1604, + 5789, + 7590, + 100, + 7, + 144, + 6, + 765, + 13611, + 144, + 1263, + 1604, + 5789, + 7590, + 678, + 2894, + 97, + 5083, + 97, + 597, + 10744, + 98, + 7, + 144, + 6, + 3903, + 144, + 7714, + 2897, + 11491, + 59652, + 17504, + 7125, + 13189, + 102, + 7, + 144, + 6, + 765, + 13611, + 144, + 2897, + 11491, + 59652, + 17504, + 59626, + 12295, + 2618, + 24768, + 12780, + 2644, + 101, + 6774, + 10495, + 13189, + 5505, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 59568, + 2897, + 11491, + 101, + 60604, + 59867, + 60084, + 60160, + 101, + 29874, + 2897, + 2253, + 59732, + 23806, + 12780, + 102, + 6774, + 8224, + 4983, + 7304, + 5973, + 105, + 11152, + 44353, + 101, + 1229, + 7252, + 33282, + 32371, + 59652, + 5186, + 102, + 2897, + 11491, + 59599, + 17713, + 9217, + 39485, + 4245, + 3600, + 9217, + 59670, + 23806, + 101, + 3438, + 6985, + 2749, + 2897, + 4195, + 1540, + 23525, + 102, + 59568, + 59732, + 17504, + 15703, + 6748, + 8262, + 4576, + 17107, + 59732, + 23806, + 12780, + 2644, + 101, + 6774, + 8224, + 4983, + 59648, + 4576, + 2479, + 4207, + 4014, + 1893, + 101, + 59738, + 8276, + 105, + 5601, + 24111, + 59652, + 6280, + 18192, + 59748, + 102, + 17504, + 8224, + 1867, + 57898, + 59599, + 17713, + 9217, + 101, + 4502, + 1229, + 17456, + 59719, + 38298, + 59652, + 59509, + 7252, + 102, + 3574, + 17504, + 1867, + 3392, + 8262, + 28830, + 15595, + 101, + 6774, + 8224, + 18802, + 59691, + 60326, + 25829, + 105, + 59647, + 22981, + 25145, + 24724, + 22981, + 43452, + 11553, + 60101, + 101, + 59659, + 3206, + 13890, + 7252, + 7042, + 59652, + 26815, + 102, + 59568, + 20237, + 101, + 2897, + 11491, + 59652, + 17504, + 7125, + 2461, + 13189, + 5505, + 30281, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 2897, + 11491, + 4983, + 2897, + 2253, + 101, + 59732, + 17504, + 4983, + 3392, + 8262, + 28830, + 15595, + 102, + 1893, + 12507, + 17713, + 9217, + 8224, + 59806, + 2897, + 
11491, + 16549, + 101, + 59659, + 5731, + 13890, + 9176, + 59652, + 26815, + 102, + 7 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 1263, + 1604, + 5789, + 7590, + 678, + 2894, + 97, + 5083, + 97, + 597, + 10744, + 98, + 7, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 2897, + 11491, + 59652, + 17504, + 59626, + 12295, + 2618, + 24768, + 12780, + 2644, + 101, + 6774, + 10495, + 13189, + 5505, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 59568, + 2897, + 11491, + 101, + 60604, + 59867, + 60084, + 60160, + 101, + 29874, + 2897, + 2253, + 59732, + 23806, + 12780, + 102, + 6774, + 8224, + 4983, + 7304, + 5973, + 105, + 11152, + 44353, + 101, + 1229, + 7252, + 33282, + 32371, + 59652, + 5186, + 102, + 2897, + 11491, + 59599, + 17713, + 9217, + 39485, + 4245, + 3600, + 9217, + 59670, + 23806, + 101, + 3438, + 6985, + 2749, + 2897, + 4195, + 1540, + 23525, + 102, + 59568, + 59732, + 17504, + 15703, + 6748, + 8262, + 4576, + 17107, + 59732, + 23806, + 12780, + 2644, + 101, + 6774, + 8224, + 4983, + 59648, + 4576, + 2479, + 4207, + 4014, + 1893, + 101, + 59738, + 8276, + 105, + 5601, + 24111, + 59652, + 6280, + 18192, + 59748, + 102, + 17504, + 8224, + 1867, + 57898, + 59599, + 17713, + 9217, + 101, + 4502, + 1229, + 17456, + 59719, + 38298, + 59652, + 59509, + 7252, + 102, + 3574, + 17504, + 1867, + 3392, + 8262, + 28830, + 15595, + 101, + 6774, + 8224, + 18802, + 59691, + 60326, + 25829, + 105, + 59647, + 22981, + 25145, + 24724, + 22981, + 43452, + 11553, + 60101, + 101, + 59659, + 3206, + 13890, + 7252, + 
7042, + 59652, + 26815, + 102, + 59568, + 20237, + 101, + 2897, + 11491, + 59652, + 17504, + 7125, + 2461, + 13189, + 5505, + 30281, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 2897, + 11491, + 4983, + 2897, + 2253, + 101, + 59732, + 17504, + 4983, + 3392, + 8262, + 28830, + 15595, + 102, + 1893, + 12507, + 17713, + 9217, + 8224, + 59806, + 2897, + 11491, + 16549, + 101, + 59659, + 5731, + 13890, + 9176, + 59652, + 26815, + 102, + 7 + ], + "inputs_decode": "<|startoftext|> A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_start|>user\nWhat are the three primary colors?<|im_end|> \n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|> \n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|> \n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>", + "labels_decode": [ + " The three primary colors are red, blue, and yellow.<|im_end|>", + " 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>" + ], + "seq_length": 286, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-34B.json 
b/applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-34B.json new file mode 100644 index 000000000..424a737d5 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/01-ai_Yi-34B.json @@ -0,0 +1,607 @@ +{ + "input_ids": [ + 1, + 59653, + 34448, + 59651, + 7488, + 55701, + 4422, + 144, + 59603, + 9334, + 1397, + 562, + 13310, + 2756, + 597, + 663, + 15874, + 10357, + 14135, + 98, + 707, + 14135, + 3641, + 6901, + 97, + 7283, + 97, + 597, + 31081, + 8476, + 592, + 567, + 2756, + 59610, + 59575, + 3275, + 98, + 144, + 144, + 144, + 59666, + 1359, + 55701, + 4422, + 144, + 144, + 5697, + 678, + 567, + 1604, + 5789, + 7590, + 100, + 1273, + 59598, + 34448, + 59651, + 707, + 1604, + 5789, + 7590, + 678, + 2894, + 97, + 5083, + 97, + 597, + 10744, + 98, + 59568, + 2, + 1, + 59653, + 34448, + 59651, + 59568, + 7714, + 2897, + 11491, + 59652, + 17504, + 7125, + 13189, + 102, + 1273, + 59598, + 34448, + 59651, + 59568, + 2897, + 11491, + 59652, + 17504, + 59626, + 12295, + 2618, + 24768, + 12780, + 2644, + 101, + 6774, + 10495, + 13189, + 5505, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 59568, + 2897, + 11491, + 101, + 60604, + 59867, + 60084, + 60160, + 101, + 29874, + 2897, + 2253, + 59732, + 23806, + 12780, + 102, + 6774, + 8224, + 4983, + 7304, + 5973, + 105, + 11152, + 44353, + 101, + 1229, + 7252, + 33282, + 32371, + 59652, + 5186, + 102, + 2897, + 11491, + 59599, + 17713, + 9217, + 39485, + 4245, + 3600, + 9217, + 59670, + 23806, + 101, + 3438, + 6985, + 2749, + 2897, + 4195, + 1540, + 23525, + 102, + 59568, + 59732, + 17504, + 15703, + 6748, + 8262, + 4576, + 17107, + 59732, + 23806, + 12780, + 2644, + 101, + 6774, + 8224, + 4983, + 59648, + 4576, + 2479, + 4207, + 4014, + 1893, + 101, + 59738, + 8276, + 105, + 5601, + 24111, + 59652, + 6280, + 18192, + 59748, + 102, + 17504, + 8224, + 1867, + 57898, + 59599, + 17713, + 9217, + 101, + 4502, + 1229, + 17456, + 59719, + 38298, + 59652, + 59509, + 7252, + 102, + 3574, + 
17504, + 1867, + 3392, + 8262, + 28830, + 15595, + 101, + 6774, + 8224, + 18802, + 59691, + 60326, + 25829, + 105, + 59647, + 22981, + 25145, + 24724, + 22981, + 43452, + 11553, + 60101, + 101, + 59659, + 3206, + 13890, + 7252, + 7042, + 59652, + 26815, + 102, + 59568, + 20237, + 101, + 2897, + 11491, + 59652, + 17504, + 7125, + 2461, + 13189, + 5505, + 30281, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 2897, + 11491, + 4983, + 2897, + 2253, + 101, + 59732, + 17504, + 4983, + 3392, + 8262, + 28830, + 15595, + 102, + 1893, + 12507, + 17713, + 9217, + 8224, + 59806, + 2897, + 11491, + 16549, + 101, + 59659, + 5731, + 13890, + 9176, + 59652, + 26815, + 102, + 59568, + 2 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 707, + 1604, + 5789, + 7590, + 678, + 2894, + 97, + 5083, + 97, + 597, + 10744, + 98, + 59568, + 2, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 59568, + 2897, + 11491, + 59652, + 17504, + 59626, + 12295, + 2618, + 24768, + 12780, + 2644, + 101, + 6774, + 10495, + 13189, + 5505, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 59568, + 2897, + 11491, + 101, + 60604, + 59867, + 60084, + 60160, + 101, + 29874, + 2897, + 2253, + 59732, + 23806, + 12780, + 102, + 6774, + 8224, + 4983, + 7304, + 5973, + 105, + 11152, + 44353, + 101, + 1229, + 7252, + 33282, + 32371, + 59652, + 5186, + 102, + 2897, + 11491, + 59599, + 17713, + 9217, + 39485, + 4245, + 3600, + 9217, + 59670, + 23806, + 101, + 3438, + 6985, + 2749, + 2897, + 
4195, + 1540, + 23525, + 102, + 59568, + 59732, + 17504, + 15703, + 6748, + 8262, + 4576, + 17107, + 59732, + 23806, + 12780, + 2644, + 101, + 6774, + 8224, + 4983, + 59648, + 4576, + 2479, + 4207, + 4014, + 1893, + 101, + 59738, + 8276, + 105, + 5601, + 24111, + 59652, + 6280, + 18192, + 59748, + 102, + 17504, + 8224, + 1867, + 57898, + 59599, + 17713, + 9217, + 101, + 4502, + 1229, + 17456, + 59719, + 38298, + 59652, + 59509, + 7252, + 102, + 3574, + 17504, + 1867, + 3392, + 8262, + 28830, + 15595, + 101, + 6774, + 8224, + 18802, + 59691, + 60326, + 25829, + 105, + 59647, + 22981, + 25145, + 24724, + 22981, + 43452, + 11553, + 60101, + 101, + 59659, + 3206, + 13890, + 7252, + 7042, + 59652, + 26815, + 102, + 59568, + 20237, + 101, + 2897, + 11491, + 59652, + 17504, + 7125, + 2461, + 13189, + 5505, + 30281, + 23337, + 105, + 17713, + 9217, + 59652, + 9176, + 102, + 2897, + 11491, + 4983, + 2897, + 2253, + 101, + 59732, + 17504, + 4983, + 3392, + 8262, + 28830, + 15595, + 102, + 1893, + 12507, + 17713, + 9217, + 8224, + 59806, + 2897, + 11491, + 16549, + 101, + 59659, + 5731, + 13890, + 9176, + 59652, + 26815, + 102, + 59568, + 2 + ], + "inputs_decode": "<|startoftext|> [INST] <>\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n\n<>\n\nWhat are the three primary colors? [/INST] The three primary colors are red, blue, and yellow. 
<|endoftext|><|startoftext|> [INST] 解释个人电脑和服务器之间的区别。 [/INST] 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。 <|endoftext|>", + "labels_decode": [ + " The three primary colors are red, blue, and yellow. <|endoftext|>", + " 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。 <|endoftext|>" + ], + "seq_length": 297, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/Qwen_Qwen-7B-Chat.json b/applications/ColossalChat/tests/test_data/chat_template/Qwen_Qwen-7B-Chat.json new file mode 100644 index 000000000..1ebfac1d6 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/Qwen_Qwen-7B-Chat.json @@ -0,0 +1,603 @@ +{ + "input_ids": [ + 151644, + 8948, + 198, + 32, + 6236, + 1948, + 264, + 22208, + 3738, + 323, + 458, + 20443, + 11229, + 17847, + 13, + 576, + 17847, + 6696, + 10950, + 11, + 11682, + 11, + 323, + 47787, + 11253, + 311, + 279, + 3738, + 594, + 4755, + 382, + 151645, + 198, + 151644, + 872, + 198, + 3838, + 525, + 279, + 2326, + 6028, + 7987, + 30, + 151645, + 198, + 151644, + 77091, + 198, + 785, + 2326, + 6028, + 7987, + 525, + 2518, + 11, + 6303, + 11, + 323, + 13753, + 13, + 151645, + 198, + 151644, + 872, + 198, + 104136, + 99605, + 104145, + 33108, + 89047, + 104186, + 102665, + 1773, + 151645, + 
198, + 151644, + 77091, + 198, + 99605, + 104145, + 33108, + 89047, + 20412, + 101441, + 99604, + 109963, + 104564, + 72448, + 3837, + 104017, + 104396, + 102665, + 101321, + 105795, + 5373, + 105433, + 85767, + 33108, + 102111, + 1773, + 220, + 99605, + 104145, + 3837, + 99846, + 13072, + 90663, + 64559, + 3837, + 20412, + 17714, + 99605, + 37029, + 68536, + 70500, + 9370, + 104564, + 1773, + 104017, + 102119, + 100751, + 101254, + 104066, + 5373, + 100415, + 33108, + 100134, + 3837, + 73670, + 104001, + 100646, + 99200, + 100535, + 113384, + 33108, + 99329, + 1773, + 99605, + 104145, + 9370, + 105433, + 85767, + 111071, + 101892, + 100142, + 85767, + 36407, + 70500, + 9370, + 3837, + 100632, + 104047, + 100345, + 99605, + 100354, + 71817, + 104790, + 1773, + 8908, + 222, + 234, + 89047, + 104802, + 101929, + 100722, + 20002, + 104378, + 68536, + 70500, + 9370, + 104564, + 72448, + 3837, + 104017, + 102119, + 100751, + 17714, + 110782, + 100646, + 71356, + 47874, + 3837, + 29524, + 100010, + 5373, + 116617, + 33108, + 26898, + 107468, + 49567, + 1773, + 89047, + 102119, + 85106, + 113313, + 9370, + 105433, + 85767, + 90395, + 100136, + 73670, + 105581, + 44636, + 118878, + 33108, + 102612, + 9370, + 104001, + 1773, + 101887, + 89047, + 85106, + 100143, + 100722, + 107494, + 104925, + 3837, + 104017, + 102119, + 102578, + 42140, + 71137, + 111249, + 5373, + 26288, + 106656, + 111390, + 33108, + 26288, + 106656, + 114897, + 102474, + 31548, + 3837, + 23031, + 100627, + 105743, + 104001, + 101149, + 33108, + 108239, + 1773, + 90476, + 119, + 53930, + 3837, + 99605, + 104145, + 33108, + 89047, + 104186, + 99558, + 102665, + 101321, + 104017, + 9370, + 105795, + 5373, + 105433, + 85767, + 33108, + 102111, + 1773, + 99605, + 104145, + 100751, + 99605, + 37029, + 3837, + 68536, + 89047, + 100751, + 100143, + 100722, + 107494, + 104925, + 1773, + 89047, + 9370, + 105433, + 85767, + 102119, + 56006, + 99605, + 104145, + 105377, + 3837, + 23031, + 101907, + 105743, + 
102111, + 33108, + 108239, + 1773, + 151645 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 785, + 2326, + 6028, + 7987, + 525, + 2518, + 11, + 6303, + 11, + 323, + 13753, + 13, + 151645, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 99605, + 104145, + 33108, + 89047, + 20412, + 101441, + 99604, + 109963, + 104564, + 72448, + 3837, + 104017, + 104396, + 102665, + 101321, + 105795, + 5373, + 105433, + 85767, + 33108, + 102111, + 1773, + 220, + 99605, + 104145, + 3837, + 99846, + 13072, + 90663, + 64559, + 3837, + 20412, + 17714, + 99605, + 37029, + 68536, + 70500, + 9370, + 104564, + 1773, + 104017, + 102119, + 100751, + 101254, + 104066, + 5373, + 100415, + 33108, + 100134, + 3837, + 73670, + 104001, + 100646, + 99200, + 100535, + 113384, + 33108, + 99329, + 1773, + 99605, + 104145, + 9370, + 105433, + 85767, + 111071, + 101892, + 100142, + 85767, + 36407, + 70500, + 9370, + 3837, + 100632, + 104047, + 100345, + 99605, + 100354, + 71817, + 104790, + 1773, + 8908, + 222, + 234, + 89047, + 104802, + 101929, + 100722, + 20002, + 104378, + 68536, + 70500, + 9370, + 104564, + 72448, + 3837, + 104017, + 102119, + 100751, + 17714, + 110782, + 100646, + 71356, + 47874, + 3837, + 29524, + 100010, + 5373, + 116617, + 33108, + 26898, + 107468, + 49567, + 1773, + 89047, + 102119, + 85106, + 113313, + 9370, + 105433, + 85767, + 90395, + 100136, + 73670, + 105581, + 44636, + 118878, + 33108, + 102612, + 9370, + 104001, + 1773, + 101887, + 89047, + 85106, + 100143, + 100722, + 107494, + 104925, + 3837, + 104017, + 102119, + 102578, + 
42140, + 71137, + 111249, + 5373, + 26288, + 106656, + 111390, + 33108, + 26288, + 106656, + 114897, + 102474, + 31548, + 3837, + 23031, + 100627, + 105743, + 104001, + 101149, + 33108, + 108239, + 1773, + 90476, + 119, + 53930, + 3837, + 99605, + 104145, + 33108, + 89047, + 104186, + 99558, + 102665, + 101321, + 104017, + 9370, + 105795, + 5373, + 105433, + 85767, + 33108, + 102111, + 1773, + 99605, + 104145, + 100751, + 99605, + 37029, + 3837, + 68536, + 89047, + 100751, + 100143, + 100722, + 107494, + 104925, + 1773, + 89047, + 9370, + 105433, + 85767, + 102119, + 56006, + 99605, + 104145, + 105377, + 3837, + 23031, + 101907, + 105743, + 102111, + 33108, + 108239, + 1773, + 151645 + ], + "inputs_decode": "<|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\nWhat are the three primary colors?<|im_end|>\n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|>\n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|>\n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>", + "labels_decode": [ + "The three primary colors are red, blue, and yellow.<|im_end|>", + "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 
总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>" + ], + "seq_length": 295, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm2-6b.json b/applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm2-6b.json new file mode 100644 index 000000000..bc421f470 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm2-6b.json @@ -0,0 +1,715 @@ +{ + "input_ids": [ + 906, + 31007, + 326, + 30962, + 6631, + 31007, + 30994, + 13361, + 13, + 30938, + 8911, + 1074, + 260, + 10976, + 1869, + 293, + 284, + 12060, + 7151, + 9319, + 30930, + 353, + 9319, + 3510, + 6483, + 30932, + 6374, + 30932, + 293, + 25712, + 7115, + 289, + 267, + 1869, + 30953, + 30917, + 2554, + 30930, + 13, + 13, + 31002, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + 13, + 31002, + 31007, + 326, + 30962, + 6631, + 31007, + 30994, + 4865, + 13, + 1266, + 383, + 267, + 1194, + 4390, + 7129, + 30987, + 906, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + 13, + 31002, + 31007, + 326, + 30962, + 6631, + 31007, + 30994, + 530, + 18971, + 13, + 353, + 1194, + 4390, + 7129, + 383, + 2603, + 30932, + 4610, + 30932, + 293, + 8362, + 30930, + 31002, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + 30910, + 13, + 31002, + 31007, + 326, + 30962, + 6631, + 31007, + 30994, + 4865, + 13, + 30910, + 32929, + 31745, + 34032, + 37154, + 55083, + 32695, + 34213, + 31155, + 906, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + 13, + 31002, + 31007, + 326, + 30962, + 6631, + 31007, + 30994, + 530, + 18971, + 13, + 30910, + 31745, + 34032, + 37154, + 55083, + 54532, + 33701, + 31723, + 39141, + 33606, + 31739, + 31123, + 32542, + 33248, + 34213, + 33260, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 30910, + 31745, + 34032, + 31123, + 55502, + 54653, + 54872, + 54923, + 31123, + 38391, + 31745, + 31695, + 54617, + 
37538, + 33606, + 31155, + 32542, + 32955, + 32419, + 32432, + 32249, + 31201, + 32645, + 44505, + 31123, + 31628, + 32557, + 41054, + 43569, + 54542, + 32033, + 31155, + 31745, + 34032, + 54530, + 36073, + 33339, + 47907, + 32001, + 31858, + 33339, + 54556, + 37538, + 31123, + 31925, + 32591, + 31793, + 31745, + 31987, + 31636, + 36058, + 31155, + 43833, + 38107, + 35231, + 32454, + 32325, + 32053, + 35299, + 54617, + 37538, + 33606, + 31739, + 31123, + 32542, + 32955, + 32419, + 54541, + 49193, + 31986, + 31863, + 31645, + 31123, + 54627, + 32438, + 31201, + 48747, + 54542, + 32410, + 38520, + 54609, + 31155, + 38107, + 32955, + 31665, + 46294, + 54530, + 36073, + 33339, + 31123, + 32187, + 31628, + 35757, + 54589, + 54376, + 54542, + 53144, + 32557, + 31155, + 31949, + 38107, + 31665, + 31818, + 32325, + 39195, + 33967, + 31123, + 32542, + 32955, + 34881, + 54573, + 55110, + 39727, + 31201, + 54539, + 36608, + 40538, + 37251, + 36608, + 52898, + 34861, + 55083, + 31123, + 54548, + 31803, + 34650, + 32557, + 32615, + 54542, + 39941, + 31155, + 30910, + 37169, + 31123, + 31745, + 34032, + 37154, + 55083, + 32695, + 31703, + 34213, + 33260, + 41215, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 31745, + 34032, + 32419, + 31745, + 31695, + 31123, + 54617, + 38107, + 32419, + 31818, + 32325, + 39195, + 33967, + 31155, + 31645, + 35949, + 36073, + 33339, + 32955, + 54703, + 31745, + 34032, + 34732, + 31123, + 54548, + 32444, + 34650, + 33116, + 54542, + 39941, + 31155, + 31002, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + 2 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, 
+ -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 353, + 1194, + 4390, + 7129, + 383, + 2603, + 30932, + 4610, + 30932, + 293, + 8362, + 30930, + 31002, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 30910, + 31745, + 34032, + 37154, + 55083, + 54532, + 33701, + 31723, + 39141, + 33606, + 31739, + 31123, + 32542, + 33248, + 34213, + 33260, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 30910, + 31745, + 34032, + 31123, + 55502, + 54653, + 54872, + 54923, + 31123, + 38391, + 31745, + 31695, + 54617, + 37538, + 33606, + 31155, + 32542, + 32955, + 32419, + 32432, + 32249, + 31201, + 32645, + 44505, + 31123, + 31628, + 32557, + 41054, + 43569, + 54542, + 32033, + 31155, + 31745, + 34032, + 54530, + 36073, + 33339, + 47907, + 32001, + 31858, + 33339, + 54556, + 37538, + 31123, + 31925, + 32591, + 31793, + 31745, + 31987, + 31636, + 36058, + 31155, + 43833, + 38107, + 35231, + 32454, + 32325, + 32053, + 35299, + 54617, + 37538, + 33606, + 31739, + 31123, + 32542, + 32955, + 32419, + 54541, + 49193, + 31986, + 31863, + 31645, + 31123, + 54627, + 32438, + 31201, + 48747, + 54542, + 32410, + 38520, + 54609, + 31155, + 38107, + 32955, + 31665, + 46294, + 54530, + 36073, + 33339, + 31123, + 32187, + 31628, + 35757, + 54589, + 54376, + 54542, + 53144, + 32557, + 31155, + 31949, + 38107, + 31665, + 31818, + 32325, + 39195, + 33967, + 31123, + 32542, + 32955, + 34881, + 54573, + 55110, + 39727, + 31201, + 54539, + 36608, + 40538, + 37251, + 36608, + 52898, + 
34861, + 55083, + 31123, + 54548, + 31803, + 34650, + 32557, + 32615, + 54542, + 39941, + 31155, + 30910, + 37169, + 31123, + 31745, + 34032, + 37154, + 55083, + 32695, + 31703, + 34213, + 33260, + 41215, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 31745, + 34032, + 32419, + 31745, + 31695, + 31123, + 54617, + 38107, + 32419, + 31818, + 32325, + 39195, + 33967, + 31155, + 31645, + 35949, + 36073, + 33339, + 32955, + 54703, + 31745, + 34032, + 34732, + 31123, + 54548, + 32444, + 34650, + 33116, + 54542, + 39941, + 31155, + 31002, + 31007, + 326, + 30962, + 437, + 31007, + 30994, + 2 + ], + "inputs_decode": "<|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\n What are the three primary colors? <|im_end|>\n<|im_start|>assistant\n The three primary colors are red, blue, and yellow.<|im_end|> \n<|im_start|>user\n 解释个人电脑和服务器之间的区别。 <|im_end|>\n<|im_start|>assistant\n 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>", + "labels_decode": [ + "The three primary colors are red, blue, and yellow.<|im_end|>", + "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>" + ], + "seq_length": 351, + "seq_category": "None" +} \ No 
newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm3-6b.json b/applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm3-6b.json new file mode 100644 index 000000000..5aab0073e --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/THUDM_chatglm3-6b.json @@ -0,0 +1,585 @@ +{ + "input_ids": [ + 64790, + 64792, + 906, + 31007, + 13361, + 31007, + 30994, + 13, + 316, + 8911, + 1074, + 260, + 10976, + 1869, + 293, + 284, + 12060, + 7151, + 9319, + 30930, + 353, + 9319, + 3510, + 6483, + 30932, + 6374, + 30932, + 293, + 25712, + 7115, + 289, + 267, + 1869, + 30953, + 30917, + 2554, + 30930, + 13, + 13, + 64795, + 30910, + 13, + 265, + 5011, + 383, + 267, + 1194, + 4390, + 7129, + 30987, + 64796, + 30910, + 13, + 265, + 1036, + 1194, + 4390, + 7129, + 383, + 2603, + 30932, + 4610, + 30932, + 293, + 8362, + 30930, + 64795, + 30910, + 13, + 265, + 32929, + 31745, + 34032, + 37154, + 55083, + 32695, + 34213, + 31155, + 64796, + 30910, + 13, + 265, + 31745, + 34032, + 37154, + 55083, + 54532, + 33701, + 31723, + 39141, + 33606, + 31739, + 31123, + 32542, + 33248, + 34213, + 33260, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 30910, + 31745, + 34032, + 31123, + 55502, + 54653, + 54872, + 54923, + 31123, + 38391, + 31745, + 31695, + 54617, + 37538, + 33606, + 31155, + 32542, + 32955, + 32419, + 32432, + 32249, + 31201, + 32645, + 44505, + 31123, + 31628, + 32557, + 41054, + 43569, + 54542, + 32033, + 31155, + 31745, + 34032, + 54530, + 36073, + 33339, + 47907, + 32001, + 31858, + 33339, + 54556, + 37538, + 31123, + 31925, + 32591, + 31793, + 31745, + 31987, + 31636, + 36058, + 31155, + 43833, + 38107, + 35231, + 32454, + 32325, + 32053, + 35299, + 54617, + 37538, + 33606, + 31739, + 31123, + 32542, + 32955, + 32419, + 54541, + 49193, + 31986, + 31863, + 31645, + 31123, + 54627, + 32438, + 31201, + 48747, + 54542, + 32410, + 38520, + 54609, + 31155, + 38107, + 32955, + 
31665, + 46294, + 54530, + 36073, + 33339, + 31123, + 32187, + 31628, + 35757, + 54589, + 54376, + 54542, + 53144, + 32557, + 31155, + 31949, + 38107, + 31665, + 31818, + 32325, + 39195, + 33967, + 31123, + 32542, + 32955, + 34881, + 54573, + 55110, + 39727, + 31201, + 54539, + 36608, + 40538, + 37251, + 36608, + 52898, + 34861, + 55083, + 31123, + 54548, + 31803, + 34650, + 32557, + 32615, + 54542, + 39941, + 31155, + 30910, + 37169, + 31123, + 31745, + 34032, + 37154, + 55083, + 32695, + 31703, + 34213, + 33260, + 41215, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 31745, + 34032, + 32419, + 31745, + 31695, + 31123, + 54617, + 38107, + 32419, + 31818, + 32325, + 39195, + 33967, + 31155, + 31645, + 35949, + 36073, + 33339, + 32955, + 54703, + 31745, + 34032, + 34732, + 31123, + 54548, + 32444, + 34650, + 33116, + 54542, + 39941, + 31155, + 2 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 265, + 1036, + 1194, + 4390, + 7129, + 383, + 2603, + 30932, + 4610, + 30932, + 293, + 8362, + 30930, + 64795, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 265, + 31745, + 34032, + 37154, + 55083, + 54532, + 33701, + 31723, + 39141, + 33606, + 31739, + 31123, + 32542, + 33248, + 34213, + 33260, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 30910, + 31745, + 34032, + 31123, + 55502, + 54653, + 54872, + 54923, + 31123, + 38391, + 31745, + 31695, + 54617, + 37538, + 33606, + 31155, + 32542, + 32955, + 32419, + 32432, + 32249, + 31201, + 32645, + 44505, + 31123, + 31628, + 32557, + 41054, 
+ 43569, + 54542, + 32033, + 31155, + 31745, + 34032, + 54530, + 36073, + 33339, + 47907, + 32001, + 31858, + 33339, + 54556, + 37538, + 31123, + 31925, + 32591, + 31793, + 31745, + 31987, + 31636, + 36058, + 31155, + 43833, + 38107, + 35231, + 32454, + 32325, + 32053, + 35299, + 54617, + 37538, + 33606, + 31739, + 31123, + 32542, + 32955, + 32419, + 54541, + 49193, + 31986, + 31863, + 31645, + 31123, + 54627, + 32438, + 31201, + 48747, + 54542, + 32410, + 38520, + 54609, + 31155, + 38107, + 32955, + 31665, + 46294, + 54530, + 36073, + 33339, + 31123, + 32187, + 31628, + 35757, + 54589, + 54376, + 54542, + 53144, + 32557, + 31155, + 31949, + 38107, + 31665, + 31818, + 32325, + 39195, + 33967, + 31123, + 32542, + 32955, + 34881, + 54573, + 55110, + 39727, + 31201, + 54539, + 36608, + 40538, + 37251, + 36608, + 52898, + 34861, + 55083, + 31123, + 54548, + 31803, + 34650, + 32557, + 32615, + 54542, + 39941, + 31155, + 30910, + 37169, + 31123, + 31745, + 34032, + 37154, + 55083, + 32695, + 31703, + 34213, + 33260, + 41215, + 35687, + 31201, + 36073, + 33339, + 54542, + 33116, + 31155, + 31745, + 34032, + 32419, + 31745, + 31695, + 31123, + 54617, + 38107, + 32419, + 31818, + 32325, + 39195, + 33967, + 31155, + 31645, + 35949, + 36073, + 33339, + 32955, + 54703, + 31745, + 34032, + 34732, + 31123, + 54548, + 32444, + 34650, + 33116, + 54542, + 39941, + 31155, + 2 + ], + "inputs_decode": "[gMASK] sop <|system|>\n A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n <|user|> \n What are the three primary colors? <|assistant|> \n The three primary colors are red, blue, and yellow. 
<|user|> \n 解释个人电脑和服务器之间的区别。 <|assistant|> \n 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。", + "labels_decode": [ + " The three primary colors are red, blue, and yellow. <|user|>", + " 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。" + ], + "seq_length": 286, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/baichuan-inc_Baichuan2-13B-Chat.json b/applications/ColossalChat/tests/test_data/chat_template/baichuan-inc_Baichuan2-13B-Chat.json new file mode 100644 index 000000000..736baf85b --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/baichuan-inc_Baichuan2-13B-Chat.json @@ -0,0 +1,697 @@ +{ + "input_ids": [ + 1, + 92655, + 92647, + 1418, + 92484, + 13387, + 92647, + 92574, + 24399, + 5, + 92343, + 15161, + 2357, + 1346, + 23815, + 3558, + 1377, + 1452, + 19649, + 11656, + 14002, + 72, + 1481, + 14002, + 6474, + 13629, + 92323, + 11144, + 92323, + 1377, + 70217, + 13514, + 1375, + 1352, + 3558, + 92404, + 92319, + 4852, + 72, + 5, + 5, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + 5, + 92655, + 92647, + 1418, + 92484, + 13387, + 92647, + 92574, + 8589, + 5, + 3167, + 1484, + 1352, + 2397, + 7721, + 12654, + 74, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + 5, + 
92655, + 92647, + 1418, + 92484, + 13387, + 92647, + 92574, + 1613, + 10685, + 5, + 1524, + 2397, + 7721, + 12654, + 1484, + 5005, + 92323, + 8488, + 92323, + 1377, + 16149, + 72, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + 5, + 92655, + 92647, + 1418, + 92484, + 13387, + 92647, + 92574, + 8589, + 5, + 5987, + 2076, + 7283, + 18431, + 93107, + 5332, + 8592, + 66, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + 5, + 92655, + 92647, + 1418, + 92484, + 13387, + 92647, + 92574, + 1613, + 10685, + 5, + 2076, + 7283, + 18431, + 93107, + 92347, + 8653, + 2381, + 22152, + 9403, + 2274, + 65, + 20962, + 2231, + 8592, + 6795, + 13087, + 69, + 18512, + 7587, + 92385, + 8222, + 66, + 92311, + 2076, + 7283, + 65, + 52905, + 65, + 23044, + 2076, + 1964, + 92492, + 17274, + 9403, + 66, + 6085, + 6984, + 4349, + 4161, + 4588, + 69, + 5748, + 30443, + 65, + 1583, + 5155, + 24491, + 48901, + 92385, + 3872, + 66, + 2076, + 7283, + 92333, + 18512, + 7587, + 28068, + 2592, + 2417, + 7587, + 92393, + 17274, + 65, + 2776, + 4463, + 78704, + 3173, + 1697, + 9501, + 66, + 100178, + 31391, + 10551, + 4126, + 4121, + 3723, + 13169, + 92492, + 17274, + 9403, + 2274, + 65, + 6085, + 6984, + 4349, + 57325, + 2099, + 2936, + 2012, + 1760, + 65, + 92467, + 4516, + 69, + 36308, + 92385, + 4733, + 25092, + 92457, + 66, + 31391, + 6984, + 1836, + 49902, + 92333, + 18512, + 7587, + 65, + 3741, + 1583, + 11871, + 92434, + 55317, + 92385, + 41203, + 5155, + 66, + 3035, + 31391, + 1836, + 2590, + 4121, + 30881, + 17136, + 65, + 6085, + 6984, + 10040, + 92418, + 92888, + 40793, + 69, + 92366, + 18178, + 32629, + 18787, + 18178, + 66897, + 14685, + 93107, + 65, + 73411, + 12452, + 5155, + 82905, + 23304, + 66, + 92311, + 14748, + 65, + 2076, + 7283, + 18431, + 93107, + 5332, + 2231, + 8592, + 6795, + 20962, + 13087, + 69, + 18512, + 7587, + 92385, + 8222, + 66, + 2076, + 7283, + 4349, + 2076, + 1964, + 65, + 92492, + 31391, + 4349, + 2590, + 4121, + 30881, + 17136, + 66, 
+ 1760, + 21050, + 18512, + 7587, + 6984, + 92596, + 2076, + 7283, + 10680, + 65, + 57310, + 12452, + 8222, + 92385, + 23304, + 66, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + 2 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 1524, + 2397, + 7721, + 12654, + 1484, + 5005, + 92323, + 8488, + 92323, + 1377, + 16149, + 72, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 2076, + 7283, + 18431, + 93107, + 92347, + 8653, + 2381, + 22152, + 9403, + 2274, + 65, + 20962, + 2231, + 8592, + 6795, + 13087, + 69, + 18512, + 7587, + 92385, + 8222, + 66, + 92311, + 2076, + 7283, + 65, + 52905, + 65, + 23044, + 2076, + 1964, + 92492, + 17274, + 9403, + 66, + 6085, + 6984, + 4349, + 4161, + 4588, + 69, + 5748, + 30443, + 65, + 1583, + 5155, + 24491, + 48901, + 92385, + 3872, + 66, + 2076, + 7283, + 92333, + 18512, + 7587, + 28068, + 2592, + 2417, + 7587, + 92393, + 17274, + 65, + 2776, + 4463, + 78704, + 3173, + 1697, + 9501, + 66, + 100178, + 31391, + 10551, + 4126, + 4121, + 3723, + 13169, + 92492, + 17274, + 9403, + 2274, + 65, + 6085, + 
6984, + 4349, + 57325, + 2099, + 2936, + 2012, + 1760, + 65, + 92467, + 4516, + 69, + 36308, + 92385, + 4733, + 25092, + 92457, + 66, + 31391, + 6984, + 1836, + 49902, + 92333, + 18512, + 7587, + 65, + 3741, + 1583, + 11871, + 92434, + 55317, + 92385, + 41203, + 5155, + 66, + 3035, + 31391, + 1836, + 2590, + 4121, + 30881, + 17136, + 65, + 6085, + 6984, + 10040, + 92418, + 92888, + 40793, + 69, + 92366, + 18178, + 32629, + 18787, + 18178, + 66897, + 14685, + 93107, + 65, + 73411, + 12452, + 5155, + 82905, + 23304, + 66, + 92311, + 14748, + 65, + 2076, + 7283, + 18431, + 93107, + 5332, + 2231, + 8592, + 6795, + 20962, + 13087, + 69, + 18512, + 7587, + 92385, + 8222, + 66, + 2076, + 7283, + 4349, + 2076, + 1964, + 65, + 92492, + 31391, + 4349, + 2590, + 4121, + 30881, + 17136, + 66, + 1760, + 21050, + 18512, + 7587, + 6984, + 92596, + 2076, + 7283, + 10680, + 65, + 57310, + 12452, + 8222, + 92385, + 23304, + 66, + 92655, + 92647, + 1418, + 92484, + 1520, + 92647, + 92574, + 2 + ], + "inputs_decode": " <|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\nWhat are the three primary colors?<|im_end|>\n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|>\n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|>\n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>", + "labels_decode": [ + "The three primary colors are red, blue, and yellow.<|im_end|>", + " 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>" + ], + "seq_length": 342, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/deepseek-ai_DeepSeek-V2-Lite.json b/applications/ColossalChat/tests/test_data/chat_template/deepseek-ai_DeepSeek-V2-Lite.json new file mode 100644 index 000000000..546e95144 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/deepseek-ai_DeepSeek-V2-Lite.json @@ -0,0 +1,581 @@ +{ + "input_ids": [ + 100000, + 32, + 12465, + 1439, + 245, + 13076, + 3807, + 285, + 274, + 18050, + 15141, + 20308, + 13, + 429, + 20308, + 4380, + 9394, + 11, + 9333, + 11, + 285, + 30513, + 9789, + 276, + 254, + 3807, + 6, + 82, + 4313, + 13, + 185, + 185, + 185, + 185, + 5726, + 25, + 2461, + 418, + 254, + 1853, + 6663, + 8247, + 
30, + 185, + 185, + 77398, + 25, + 429, + 1853, + 6663, + 8247, + 418, + 3074, + 11, + 5501, + 11, + 285, + 10421, + 13, + 100001, + 5726, + 25, + 207, + 17882, + 6213, + 19462, + 50209, + 4625, + 17901, + 25628, + 398, + 185, + 185, + 77398, + 25, + 207, + 6213, + 19462, + 50209, + 4625, + 504, + 23807, + 5871, + 53947, + 33939, + 5804, + 19304, + 15047, + 24291, + 25628, + 18752, + 48548, + 537, + 34854, + 15477, + 885, + 16938, + 398, + 207, + 6213, + 19462, + 19304, + 8549, + 2217, + 4118, + 4800, + 19304, + 54533, + 6213, + 5118, + 1306, + 42753, + 33939, + 398, + 15047, + 18867, + 15055, + 16476, + 15982, + 537, + 15615, + 885, + 5782, + 19304, + 1876, + 17135, + 60574, + 98238, + 885, + 8340, + 398, + 6213, + 19462, + 337, + 34854, + 15477, + 72652, + 10263, + 7992, + 15477, + 913, + 42753, + 19304, + 6117, + 13742, + 6857, + 6213, + 9384, + 3421, + 31448, + 398, + 207, + 1306, + 59828, + 31759, + 13211, + 16690, + 8452, + 32512, + 1306, + 42753, + 33939, + 5804, + 19304, + 15047, + 18867, + 15055, + 78963, + 6037, + 8510, + 8083, + 4076, + 19304, + 1415, + 16196, + 537, + 12385, + 53913, + 885, + 18775, + 58875, + 1537, + 398, + 59828, + 18867, + 4012, + 93589, + 337, + 34854, + 15477, + 19304, + 9482, + 1876, + 33688, + 1331, + 5865, + 7438, + 885, + 88642, + 17135, + 398, + 7406, + 59828, + 4012, + 8580, + 16690, + 47322, + 45098, + 19304, + 15047, + 18867, + 27061, + 1059, + 6273, + 47500, + 537, + 748, + 45003, + 56873, + 47107, + 45003, + 81005, + 34136, + 4625, + 19304, + 851, + 8039, + 32672, + 17135, + 13951, + 885, + 59649, + 398, + 207, + 39736, + 19304, + 6213, + 19462, + 50209, + 4625, + 17901, + 5649, + 25628, + 18752, + 63889, + 48548, + 537, + 34854, + 15477, + 885, + 16938, + 398, + 6213, + 19462, + 15055, + 6213, + 5118, + 19304, + 1306, + 59828, + 15055, + 8580, + 16690, + 47322, + 45098, + 398, + 4076, + 40542, + 34854, + 15477, + 18867, + 1769, + 6213, + 19462, + 31729, + 19304, + 851, + 12383, + 32672, + 16938, + 885, + 59649, + 398, + 
100001 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 429, + 1853, + 6663, + 8247, + 418, + 3074, + 11, + 5501, + 11, + 285, + 10421, + 13, + 100001, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 207, + 6213, + 19462, + 50209, + 4625, + 504, + 23807, + 5871, + 53947, + 33939, + 5804, + 19304, + 15047, + 24291, + 25628, + 18752, + 48548, + 537, + 34854, + 15477, + 885, + 16938, + 398, + 207, + 6213, + 19462, + 19304, + 8549, + 2217, + 4118, + 4800, + 19304, + 54533, + 6213, + 5118, + 1306, + 42753, + 33939, + 398, + 15047, + 18867, + 15055, + 16476, + 15982, + 537, + 15615, + 885, + 5782, + 19304, + 1876, + 17135, + 60574, + 98238, + 885, + 8340, + 398, + 6213, + 19462, + 337, + 34854, + 15477, + 72652, + 10263, + 7992, + 15477, + 913, + 42753, + 19304, + 6117, + 13742, + 6857, + 6213, + 9384, + 3421, + 31448, + 398, + 207, + 1306, + 59828, + 31759, + 13211, + 16690, + 8452, + 32512, + 1306, + 42753, + 33939, + 5804, + 19304, + 15047, + 18867, + 15055, + 78963, + 6037, + 8510, + 8083, + 4076, + 19304, + 1415, + 16196, + 537, + 12385, + 53913, + 885, + 18775, + 58875, + 1537, + 398, + 59828, + 18867, + 4012, + 93589, + 337, + 34854, + 15477, + 19304, + 9482, + 1876, + 33688, + 1331, + 5865, + 7438, + 885, + 88642, + 17135, + 398, + 7406, + 59828, + 4012, + 8580, + 16690, + 47322, + 45098, + 19304, + 15047, + 18867, + 27061, + 1059, + 6273, + 47500, + 537, + 748, + 45003, + 56873, + 47107, + 45003, + 81005, + 34136, + 4625, + 19304, + 851, + 8039, + 32672, + 17135, + 13951, + 885, + 59649, + 398, + 207, + 39736, + 19304, + 6213, + 19462, + 50209, + 
4625, + 17901, + 5649, + 25628, + 18752, + 63889, + 48548, + 537, + 34854, + 15477, + 885, + 16938, + 398, + 6213, + 19462, + 15055, + 6213, + 5118, + 19304, + 1306, + 59828, + 15055, + 8580, + 16690, + 47322, + 45098, + 398, + 4076, + 40542, + 34854, + 15477, + 18867, + 1769, + 6213, + 19462, + 31729, + 19304, + 851, + 12383, + 32672, + 16938, + 885, + 59649, + 398, + 100001 + ], + "inputs_decode": "<|begin▁of▁sentence|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n\n\nUser: What are the three primary colors?\n\nAssistant: The three primary colors are red, blue, and yellow.<|end▁of▁sentence|>User: 解释个人电脑和服务器之间的区别。\n\nAssistant: 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|end▁of▁sentence|>", + "labels_decode": [ + " The three primary colors are red, blue, and yellow.<|end▁of▁sentence|>", + " 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|end▁of▁sentence|>" + ], + "seq_length": 284, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/microsoft_phi-2.json b/applications/ColossalChat/tests/test_data/chat_template/microsoft_phi-2.json new file mode 100644 index 000000000..f43ab7f4c --- /dev/null +++ 
b/applications/ColossalChat/tests/test_data/chat_template/microsoft_phi-2.json @@ -0,0 +1,2009 @@ +{ + "input_ids": [ + 50256, + 27, + 91, + 320, + 62, + 9688, + 91, + 29, + 10057, + 198, + 32, + 8537, + 1022, + 257, + 11040, + 1692, + 290, + 281, + 11666, + 4430, + 8796, + 13, + 383, + 8796, + 3607, + 7613, + 11, + 6496, + 11, + 290, + 23507, + 7429, + 284, + 262, + 1692, + 338, + 2683, + 13, + 198, + 198, + 27, + 91, + 320, + 62, + 437, + 91, + 29, + 198, + 27, + 91, + 320, + 62, + 9688, + 91, + 29, + 7220, + 198, + 2061, + 389, + 262, + 1115, + 4165, + 7577, + 30, + 27, + 91, + 320, + 62, + 437, + 91, + 29, + 198, + 27, + 91, + 320, + 62, + 9688, + 91, + 29, + 562, + 10167, + 198, + 464, + 1115, + 4165, + 7577, + 389, + 2266, + 11, + 4171, + 11, + 290, + 7872, + 29847, + 91, + 320, + 62, + 437, + 91, + 29, + 198, + 27, + 91, + 320, + 62, + 9688, + 91, + 29, + 7220, + 198, + 164, + 100, + 96, + 34932, + 232, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 161, + 240, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 45298, + 29785, + 112, + 21410, + 44293, + 118, + 26344, + 104, + 16764, + 27, + 91, + 320, + 62, + 437, + 91, + 29, + 198, + 27, + 91, + 320, + 62, + 9688, + 91, + 29, + 562, + 10167, + 198, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 161, + 240, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 42468, + 10310, + 97, + 163, + 100, + 235, + 38834, + 28938, + 234, + 163, + 109, + 119, + 161, + 252, + 233, + 21410, + 164, + 106, + 94, + 163, + 106, + 245, + 17312, + 118, + 163, + 111, + 119, + 163, + 119, + 253, + 171, + 120, + 234, + 22522, + 225, + 20015, + 105, + 21410, + 10310, + 119, + 17358, + 223, + 44293, + 118, + 26344, + 104, + 28839, + 101, + 12859, + 236, + 18796, + 101, + 34460, + 242, + 23513, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 161, + 240, + 234, + 45250, + 100, + 47797, + 121, + 16764, + 220, + 10310, + 103, + 21689, + 18796, + 113, + 
164, + 226, + 239, + 171, + 120, + 234, + 165, + 94, + 122, + 28938, + 235, + 45250, + 251, + 20046, + 231, + 171, + 120, + 234, + 42468, + 10310, + 118, + 10310, + 103, + 21689, + 45635, + 18796, + 101, + 32003, + 234, + 164, + 106, + 122, + 164, + 106, + 94, + 21410, + 164, + 106, + 94, + 163, + 106, + 245, + 17312, + 118, + 16764, + 22522, + 225, + 20015, + 105, + 34460, + 248, + 30585, + 116, + 18796, + 101, + 12859, + 236, + 33768, + 98, + 30585, + 116, + 21410, + 32432, + 98, + 43291, + 23513, + 161, + 101, + 109, + 20046, + 238, + 161, + 240, + 234, + 27764, + 99, + 20046, + 254, + 171, + 120, + 234, + 20998, + 107, + 20015, + 98, + 32573, + 238, + 26193, + 234, + 28938, + 226, + 163, + 100, + 235, + 28938, + 226, + 43718, + 115, + 21410, + 41753, + 242, + 18796, + 101, + 163, + 101, + 233, + 41753, + 237, + 161, + 240, + 234, + 162, + 116, + 116, + 22755, + 237, + 16764, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 21410, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 31660, + 48958, + 105, + 42468, + 162, + 234, + 231, + 163, + 227, + 100, + 43718, + 229, + 49035, + 228, + 165, + 227, + 235, + 163, + 121, + 106, + 30266, + 98, + 164, + 106, + 122, + 164, + 106, + 94, + 21410, + 171, + 120, + 234, + 38834, + 32573, + 229, + 20046, + 253, + 20998, + 107, + 20015, + 98, + 43718, + 117, + 162, + 235, + 106, + 10310, + 103, + 21689, + 165, + 250, + 222, + 162, + 109, + 224, + 32573, + 249, + 26193, + 234, + 22522, + 248, + 26344, + 114, + 16764, + 5525, + 222, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 42468, + 10310, + 118, + 12859, + 228, + 162, + 119, + 94, + 164, + 114, + 111, + 32014, + 34932, + 237, + 18796, + 101, + 22755, + 115, + 21410, + 165, + 250, + 222, + 162, + 109, + 224, + 32003, + 234, + 164, + 106, + 122, + 164, + 106, + 94, + 21410, + 164, + 106, + 94, + 163, + 106, + 245, + 17312, + 118, + 163, + 111, + 119, + 163, + 119, + 253, + 171, + 120, + 234, + 22522, + 225, + 
20015, + 105, + 34460, + 248, + 30585, + 116, + 18796, + 101, + 12859, + 236, + 10310, + 118, + 18796, + 101, + 22755, + 115, + 162, + 237, + 238, + 160, + 122, + 249, + 28938, + 226, + 163, + 100, + 235, + 163, + 121, + 239, + 163, + 119, + 250, + 17312, + 235, + 27950, + 94, + 171, + 120, + 234, + 36685, + 224, + 163, + 121, + 239, + 44165, + 247, + 23513, + 18796, + 113, + 36310, + 165, + 224, + 106, + 20015, + 114, + 161, + 240, + 234, + 23877, + 229, + 20015, + 114, + 27670, + 254, + 164, + 122, + 241, + 163, + 255, + 231, + 16764, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 34460, + 248, + 30585, + 116, + 165, + 250, + 222, + 17358, + 223, + 165, + 45865, + 45250, + 100, + 47797, + 121, + 21410, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 171, + 120, + 234, + 33176, + 114, + 10310, + 242, + 20998, + 107, + 20015, + 98, + 33699, + 123, + 20998, + 245, + 165, + 45865, + 164, + 112, + 253, + 164, + 121, + 121, + 161, + 240, + 234, + 165, + 243, + 123, + 33768, + 114, + 29785, + 112, + 21410, + 32573, + 238, + 26193, + 234, + 16764, + 18796, + 109, + 12859, + 236, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 165, + 250, + 222, + 17358, + 223, + 162, + 242, + 107, + 162, + 234, + 223, + 32014, + 34932, + 237, + 18796, + 101, + 22755, + 115, + 21410, + 164, + 106, + 123, + 29785, + 106, + 171, + 120, + 234, + 22522, + 225, + 20015, + 105, + 34460, + 248, + 30585, + 116, + 165, + 227, + 235, + 13783, + 229, + 13783, + 248, + 43718, + 116, + 13783, + 226, + 49426, + 228, + 161, + 247, + 101, + 23513, + 32014, + 22522, + 117, + 34932, + 237, + 37863, + 227, + 27764, + 246, + 161, + 240, + 234, + 32014, + 22522, + 117, + 34932, + 237, + 163, + 94, + 105, + 33566, + 246, + 165, + 102, + 109, + 27950, + 101, + 161, + 247, + 101, + 171, + 120, + 234, + 20015, + 98, + 162, + 237, + 238, + 165, + 45865, + 163, + 111, + 119, + 163, + 119, + 253, + 21410, + 32573, + 238, + 26193, + 234, + 34460, + 253, + 41753, + 99, + 
161, + 240, + 234, + 163, + 101, + 111, + 22522, + 248, + 45250, + 100, + 16764, + 10545, + 222, + 119, + 45298, + 171, + 120, + 234, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 161, + 240, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 45298, + 29785, + 112, + 21410, + 10310, + 119, + 17358, + 223, + 44293, + 118, + 26344, + 104, + 28839, + 101, + 12859, + 236, + 22522, + 225, + 20015, + 105, + 21410, + 18796, + 101, + 34460, + 242, + 23513, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 161, + 240, + 234, + 45250, + 100, + 47797, + 121, + 16764, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 18796, + 101, + 12859, + 236, + 10310, + 103, + 21689, + 45635, + 18796, + 101, + 171, + 120, + 234, + 32003, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 18796, + 101, + 12859, + 236, + 162, + 242, + 107, + 162, + 234, + 223, + 32014, + 34932, + 237, + 18796, + 101, + 22755, + 115, + 21410, + 164, + 106, + 123, + 29785, + 106, + 16764, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 21410, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 34460, + 248, + 30585, + 116, + 162, + 107, + 242, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 162, + 249, + 112, + 165, + 45865, + 171, + 120, + 234, + 20015, + 98, + 46479, + 251, + 46237, + 223, + 163, + 111, + 119, + 163, + 119, + 253, + 21410, + 45250, + 100, + 47797, + 121, + 161, + 240, + 234, + 163, + 101, + 111, + 22522, + 248, + 45250, + 100, + 16764, + 27, + 91, + 320, + 62, + 437, + 91, + 29, + 50256 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 
-100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 464, + 1115, + 4165, + 7577, + 389, + 2266, + 11, + 4171, + 11, + 290, + 7872, + 29847, + 91, + 320, + 62, + 437, + 91, + 29, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 161, + 240, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 42468, + 10310, + 97, + 163, + 100, + 235, + 38834, + 28938, + 234, + 163, + 109, + 119, + 161, + 252, + 233, + 21410, + 164, + 106, + 94, + 163, + 106, + 245, + 17312, + 118, + 163, + 111, + 119, + 163, + 119, + 253, + 171, + 120, + 234, + 22522, + 225, + 20015, + 105, + 21410, + 10310, + 119, + 17358, + 223, + 44293, + 118, + 26344, + 104, + 28839, + 101, + 12859, + 236, + 18796, + 101, + 34460, + 242, + 23513, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 161, + 240, + 234, + 45250, + 100, + 47797, + 121, + 16764, + 220, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 171, + 120, + 234, + 165, + 94, + 122, + 28938, + 235, + 45250, + 251, + 20046, + 231, + 171, + 120, + 234, + 42468, + 10310, + 118, + 10310, + 103, + 21689, + 45635, + 18796, + 101, + 32003, + 234, + 164, + 106, + 122, + 164, + 106, + 94, + 21410, + 164, + 106, + 94, + 163, + 106, + 245, + 17312, + 118, + 16764, + 22522, + 225, + 
20015, + 105, + 34460, + 248, + 30585, + 116, + 18796, + 101, + 12859, + 236, + 33768, + 98, + 30585, + 116, + 21410, + 32432, + 98, + 43291, + 23513, + 161, + 101, + 109, + 20046, + 238, + 161, + 240, + 234, + 27764, + 99, + 20046, + 254, + 171, + 120, + 234, + 20998, + 107, + 20015, + 98, + 32573, + 238, + 26193, + 234, + 28938, + 226, + 163, + 100, + 235, + 28938, + 226, + 43718, + 115, + 21410, + 41753, + 242, + 18796, + 101, + 163, + 101, + 233, + 41753, + 237, + 161, + 240, + 234, + 162, + 116, + 116, + 22755, + 237, + 16764, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 21410, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 31660, + 48958, + 105, + 42468, + 162, + 234, + 231, + 163, + 227, + 100, + 43718, + 229, + 49035, + 228, + 165, + 227, + 235, + 163, + 121, + 106, + 30266, + 98, + 164, + 106, + 122, + 164, + 106, + 94, + 21410, + 171, + 120, + 234, + 38834, + 32573, + 229, + 20046, + 253, + 20998, + 107, + 20015, + 98, + 43718, + 117, + 162, + 235, + 106, + 10310, + 103, + 21689, + 165, + 250, + 222, + 162, + 109, + 224, + 32573, + 249, + 26193, + 234, + 22522, + 248, + 26344, + 114, + 16764, + 5525, + 222, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 42468, + 10310, + 118, + 12859, + 228, + 162, + 119, + 94, + 164, + 114, + 111, + 32014, + 34932, + 237, + 18796, + 101, + 22755, + 115, + 21410, + 165, + 250, + 222, + 162, + 109, + 224, + 32003, + 234, + 164, + 106, + 122, + 164, + 106, + 94, + 21410, + 164, + 106, + 94, + 163, + 106, + 245, + 17312, + 118, + 163, + 111, + 119, + 163, + 119, + 253, + 171, + 120, + 234, + 22522, + 225, + 20015, + 105, + 34460, + 248, + 30585, + 116, + 18796, + 101, + 12859, + 236, + 10310, + 118, + 18796, + 101, + 22755, + 115, + 162, + 237, + 238, + 160, + 122, + 249, + 28938, + 226, + 163, + 100, + 235, + 163, + 121, + 239, + 163, + 119, + 250, + 17312, + 235, + 27950, + 94, + 171, + 120, + 234, + 36685, + 224, + 163, + 121, + 239, + 44165, + 247, + 
23513, + 18796, + 113, + 36310, + 165, + 224, + 106, + 20015, + 114, + 161, + 240, + 234, + 23877, + 229, + 20015, + 114, + 27670, + 254, + 164, + 122, + 241, + 163, + 255, + 231, + 16764, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 34460, + 248, + 30585, + 116, + 165, + 250, + 222, + 17358, + 223, + 165, + 45865, + 45250, + 100, + 47797, + 121, + 21410, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 171, + 120, + 234, + 33176, + 114, + 10310, + 242, + 20998, + 107, + 20015, + 98, + 33699, + 123, + 20998, + 245, + 165, + 45865, + 164, + 112, + 253, + 164, + 121, + 121, + 161, + 240, + 234, + 165, + 243, + 123, + 33768, + 114, + 29785, + 112, + 21410, + 32573, + 238, + 26193, + 234, + 16764, + 18796, + 109, + 12859, + 236, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 165, + 250, + 222, + 17358, + 223, + 162, + 242, + 107, + 162, + 234, + 223, + 32014, + 34932, + 237, + 18796, + 101, + 22755, + 115, + 21410, + 164, + 106, + 123, + 29785, + 106, + 171, + 120, + 234, + 22522, + 225, + 20015, + 105, + 34460, + 248, + 30585, + 116, + 165, + 227, + 235, + 13783, + 229, + 13783, + 248, + 43718, + 116, + 13783, + 226, + 49426, + 228, + 161, + 247, + 101, + 23513, + 32014, + 22522, + 117, + 34932, + 237, + 37863, + 227, + 27764, + 246, + 161, + 240, + 234, + 32014, + 22522, + 117, + 34932, + 237, + 163, + 94, + 105, + 33566, + 246, + 165, + 102, + 109, + 27950, + 101, + 161, + 247, + 101, + 171, + 120, + 234, + 20015, + 98, + 162, + 237, + 238, + 165, + 45865, + 163, + 111, + 119, + 163, + 119, + 253, + 21410, + 32573, + 238, + 26193, + 234, + 34460, + 253, + 41753, + 99, + 161, + 240, + 234, + 163, + 101, + 111, + 22522, + 248, + 45250, + 100, + 16764, + 10545, + 222, + 119, + 45298, + 171, + 120, + 234, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 161, + 240, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 45298, + 29785, + 112, + 21410, + 10310, + 119, + 17358, + 223, + 44293, + 118, + 
26344, + 104, + 28839, + 101, + 12859, + 236, + 22522, + 225, + 20015, + 105, + 21410, + 18796, + 101, + 34460, + 242, + 23513, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 161, + 240, + 234, + 45250, + 100, + 47797, + 121, + 16764, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 18796, + 101, + 12859, + 236, + 10310, + 103, + 21689, + 45635, + 18796, + 101, + 171, + 120, + 234, + 32003, + 234, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 18796, + 101, + 12859, + 236, + 162, + 242, + 107, + 162, + 234, + 223, + 32014, + 34932, + 237, + 18796, + 101, + 22755, + 115, + 21410, + 164, + 106, + 123, + 29785, + 106, + 16764, + 17312, + 235, + 27950, + 94, + 161, + 247, + 101, + 21410, + 163, + 94, + 105, + 20015, + 114, + 165, + 227, + 235, + 163, + 121, + 106, + 34460, + 248, + 30585, + 116, + 162, + 107, + 242, + 10310, + 103, + 21689, + 18796, + 113, + 164, + 226, + 239, + 162, + 249, + 112, + 165, + 45865, + 171, + 120, + 234, + 20015, + 98, + 46479, + 251, + 46237, + 223, + 163, + 111, + 119, + 163, + 119, + 253, + 21410, + 45250, + 100, + 47797, + 121, + 161, + 240, + 234, + 163, + 101, + 111, + 22522, + 248, + 45250, + 100, + 16764, + 27, + 91, + 320, + 62, + 437, + 91, + 29, + 50256 + ], + "inputs_decode": "<|endoftext|><|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\nWhat are the three primary colors?<|im_end|>\n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|>\n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|>\n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|><|endoftext|>", + "labels_decode": [ + "The three primary colors are red, blue, and yellow.<|im_end|>", + "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|><|endoftext|>" + ], + "seq_length": 998, + "seq_category": "None" +} \ No newline at end of file diff --git a/applications/ColossalChat/tests/test_data/chat_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json b/applications/ColossalChat/tests/test_data/chat_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json new file mode 100644 index 000000000..f1979eb52 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/chat_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json @@ -0,0 +1,919 @@ +{ + "input_ids": [ + 1, + 733, + 16289, + 28793, + 28705, + 1824, + 460, + 272, + 1712, + 6258, + 9304, + 28804, + 28705, + 733, + 28748, + 16289, + 28793, + 415, + 1712, + 6258, + 9304, + 460, + 2760, + 28725, + 5045, + 28725, + 304, + 9684, + 28723, + 2, + 733, + 16289, + 28793, + 259, 
+ 29386, + 30334, + 28995, + 29086, + 29742, + 31753, + 29131, + 29430, + 29224, + 29180, + 29332, + 29117, + 28914, + 29209, + 29705, + 28944, + 28705, + 733, + 28748, + 16289, + 28793, + 28705, + 28995, + 29086, + 29742, + 31753, + 29131, + 29430, + 29224, + 29180, + 28971, + 29745, + 29824, + 28988, + 29338, + 29066, + 29173, + 28914, + 29382, + 29481, + 29363, + 29401, + 29531, + 28924, + 29928, + 29550, + 28914, + 29345, + 29059, + 29209, + 29705, + 29010, + 29160, + 28963, + 31787, + 29041, + 31483, + 29011, + 29253, + 29021, + 29131, + 29261, + 29084, + 28944, + 28705, + 28995, + 29086, + 29742, + 31753, + 28924, + 236, + 164, + 193, + 29046, + 30450, + 29476, + 28924, + 28971, + 29003, + 28995, + 29086, + 29154, + 28963, + 29746, + 29081, + 29382, + 28914, + 29382, + 29481, + 29363, + 28944, + 29928, + 29550, + 29217, + 29408, + 28963, + 29160, + 29142, + 29408, + 28914, + 29487, + 29089, + 29041, + 232, + 171, + 180, + 30488, + 29131, + 29500, + 30978, + 28924, + 29052, + 29074, + 29798, + 29037, + 30344, + 29824, + 30344, + 29675, + 28914, + 29298, + 28963, + 29265, + 29313, + 29131, + 30239, + 30727, + 28944, + 28995, + 29086, + 29742, + 31753, + 28914, + 31483, + 29011, + 29253, + 29021, + 28969, + 30812, + 28971, + 29518, + 30131, + 29144, + 30168, + 29253, + 29021, + 29263, + 29081, + 29382, + 28914, + 28924, + 28988, + 29202, + 29537, + 29052, + 29074, + 29521, + 29020, + 28995, + 29086, + 29259, + 29243, + 29258, + 29037, + 29018, + 29251, + 28944, + 28705, + 29746, + 29430, + 29224, + 29180, + 28971, + 29003, + 29105, + 30493, + 30303, + 29050, + 29195, + 28963, + 29106, + 28914, + 29259, + 29243, + 29746, + 29081, + 29382, + 28914, + 29382, + 29481, + 29363, + 29401, + 29531, + 28924, + 29928, + 29550, + 29217, + 29408, + 28963, + 29160, + 29003, + 28963, + 29106, + 29279, + 29954, + 30344, + 29824, + 29533, + 30229, + 29430, + 29224, + 28924, + 29118, + 29533, + 29992, + 29041, + 29742, + 29169, + 30182, + 29011, + 29131, + 29019, + 29011, + 
29359, + 29226, + 29414, + 28944, + 29430, + 29224, + 29180, + 29217, + 29408, + 29259, + 29059, + 29366, + 29261, + 29084, + 28914, + 31483, + 29011, + 29253, + 29021, + 28924, + 29457, + 29958, + 29052, + 29074, + 30761, + 30154, + 29366, + 30685, + 29435, + 29131, + 29495, + 29007, + 29117, + 28914, + 29798, + 29037, + 28944, + 29590, + 29160, + 29430, + 29224, + 29180, + 29259, + 29059, + 29428, + 29569, + 29050, + 29195, + 28963, + 29106, + 28914, + 30164, + 29562, + 28924, + 29928, + 29550, + 29217, + 29408, + 29253, + 29613, + 29292, + 30080, + 29326, + 29093, + 29180, + 29041, + 29050, + 29329, + 29195, + 29188, + 29079, + 29131, + 29050, + 29329, + 29195, + 31483, + 30358, + 236, + 172, + 180, + 29129, + 29180, + 28924, + 29074, + 29279, + 29366, + 29401, + 29531, + 28914, + 29798, + 29037, + 30029, + 29184, + 29131, + 234, + 171, + 182, + 29018, + 29261, + 28944, + 28705, + 29758, + 29332, + 28924, + 28995, + 29086, + 29742, + 31753, + 29131, + 29430, + 29224, + 29180, + 29332, + 29117, + 28914, + 29345, + 29059, + 29209, + 29705, + 29010, + 29160, + 29928, + 29550, + 28914, + 28963, + 31787, + 29041, + 31483, + 29011, + 29253, + 29021, + 29131, + 29261, + 29084, + 28944, + 28995, + 29086, + 29742, + 31753, + 28963, + 29160, + 28995, + 29086, + 29154, + 28963, + 28924, + 29746, + 29430, + 29224, + 29180, + 28963, + 29160, + 29428, + 29569, + 29050, + 29195, + 28963, + 29106, + 28914, + 30164, + 29562, + 28944, + 29430, + 29224, + 29180, + 28914, + 31483, + 29011, + 29253, + 29021, + 29217, + 29408, + 29605, + 28995, + 29086, + 29742, + 31753, + 29250, + 29366, + 28924, + 29074, + 29321, + 29425, + 29401, + 29531, + 28914, + 29261, + 29084, + 29131, + 234, + 171, + 182, + 29018, + 29261, + 28944, + 2 + ], + "labels": [ + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 415, + 1712, + 6258, + 9304, + 460, + 2760, + 28725, + 5045, + 28725, + 304, + 9684, + 28723, + 2, + 
-100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + -100, + 28705, + 28995, + 29086, + 29742, + 31753, + 29131, + 29430, + 29224, + 29180, + 28971, + 29745, + 29824, + 28988, + 29338, + 29066, + 29173, + 28914, + 29382, + 29481, + 29363, + 29401, + 29531, + 28924, + 29928, + 29550, + 28914, + 29345, + 29059, + 29209, + 29705, + 29010, + 29160, + 28963, + 31787, + 29041, + 31483, + 29011, + 29253, + 29021, + 29131, + 29261, + 29084, + 28944, + 28705, + 28995, + 29086, + 29742, + 31753, + 28924, + 236, + 164, + 193, + 29046, + 30450, + 29476, + 28924, + 28971, + 29003, + 28995, + 29086, + 29154, + 28963, + 29746, + 29081, + 29382, + 28914, + 29382, + 29481, + 29363, + 28944, + 29928, + 29550, + 29217, + 29408, + 28963, + 29160, + 29142, + 29408, + 28914, + 29487, + 29089, + 29041, + 232, + 171, + 180, + 30488, + 29131, + 29500, + 30978, + 28924, + 29052, + 29074, + 29798, + 29037, + 30344, + 29824, + 30344, + 29675, + 28914, + 29298, + 28963, + 29265, + 29313, + 29131, + 30239, + 30727, + 28944, + 28995, + 29086, + 29742, + 31753, + 28914, + 31483, + 29011, + 29253, + 29021, + 28969, + 30812, + 28971, + 29518, + 30131, + 29144, + 30168, + 29253, + 29021, + 29263, + 29081, + 29382, + 28914, + 28924, + 28988, + 29202, + 29537, + 29052, + 29074, + 29521, + 29020, + 28995, + 29086, + 29259, + 29243, + 29258, + 29037, + 29018, + 29251, + 28944, + 28705, + 29746, + 29430, + 29224, + 29180, + 28971, + 29003, + 29105, + 30493, + 30303, + 29050, + 29195, + 28963, + 29106, + 28914, + 29259, + 29243, + 29746, + 29081, + 29382, + 28914, + 29382, + 29481, + 29363, + 29401, + 29531, + 28924, + 29928, + 29550, + 29217, + 29408, + 28963, + 29160, + 29003, + 28963, + 29106, + 29279, + 29954, + 30344, + 29824, + 29533, + 30229, + 29430, + 29224, + 28924, + 29118, + 29533, + 29992, + 29041, + 29742, + 29169, + 30182, + 29011, + 29131, + 29019, + 
29011, + 29359, + 29226, + 29414, + 28944, + 29430, + 29224, + 29180, + 29217, + 29408, + 29259, + 29059, + 29366, + 29261, + 29084, + 28914, + 31483, + 29011, + 29253, + 29021, + 28924, + 29457, + 29958, + 29052, + 29074, + 30761, + 30154, + 29366, + 30685, + 29435, + 29131, + 29495, + 29007, + 29117, + 28914, + 29798, + 29037, + 28944, + 29590, + 29160, + 29430, + 29224, + 29180, + 29259, + 29059, + 29428, + 29569, + 29050, + 29195, + 28963, + 29106, + 28914, + 30164, + 29562, + 28924, + 29928, + 29550, + 29217, + 29408, + 29253, + 29613, + 29292, + 30080, + 29326, + 29093, + 29180, + 29041, + 29050, + 29329, + 29195, + 29188, + 29079, + 29131, + 29050, + 29329, + 29195, + 31483, + 30358, + 236, + 172, + 180, + 29129, + 29180, + 28924, + 29074, + 29279, + 29366, + 29401, + 29531, + 28914, + 29798, + 29037, + 30029, + 29184, + 29131, + 234, + 171, + 182, + 29018, + 29261, + 28944, + 28705, + 29758, + 29332, + 28924, + 28995, + 29086, + 29742, + 31753, + 29131, + 29430, + 29224, + 29180, + 29332, + 29117, + 28914, + 29345, + 29059, + 29209, + 29705, + 29010, + 29160, + 29928, + 29550, + 28914, + 28963, + 31787, + 29041, + 31483, + 29011, + 29253, + 29021, + 29131, + 29261, + 29084, + 28944, + 28995, + 29086, + 29742, + 31753, + 28963, + 29160, + 28995, + 29086, + 29154, + 28963, + 28924, + 29746, + 29430, + 29224, + 29180, + 28963, + 29160, + 29428, + 29569, + 29050, + 29195, + 28963, + 29106, + 28914, + 30164, + 29562, + 28944, + 29430, + 29224, + 29180, + 28914, + 31483, + 29011, + 29253, + 29021, + 29217, + 29408, + 29605, + 28995, + 29086, + 29742, + 31753, + 29250, + 29366, + 28924, + 29074, + 29321, + 29425, + 29401, + 29531, + 28914, + 29261, + 29084, + 29131, + 234, + 171, + 182, + 29018, + 29261, + 28944, + 2 + ], + "inputs_decode": " [INST] What are the three primary colors? [/INST] The three primary colors are red, blue, and yellow. 
[INST] 解释个人电脑和服务器之间的区别。 [/INST] 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。", + "labels_decode": [ + " The three primary colors are red, blue, and yellow.", + " 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。" + ], + "seq_length": 453, + "seq_category": "None" +} \ No newline at end of file