upgrade colossal-chat support tp_group>1, add sp for sft

pull/5759/head
YeAnbang 2024-05-27 05:55:57 +00:00
parent 73e88a5553
commit 7a7e86987d
33 changed files with 7574 additions and 105 deletions

View File

@ -17,6 +17,7 @@ class Conversation:
system_message: str
chat_template: str
stop_ids: List[int]
end_of_assistant: str
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@ -24,7 +25,7 @@ class Conversation:
Setup the conversation template from config
"""
tokenizer.chat_template = config["chat_template"]
conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"])
conv.clear()
return conv
@ -109,6 +110,8 @@ def setup_conversation_template(
"""
if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
# Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually
if "end_of_assistant" not in chat_template_config:
raise ValueError("Please set the end of assistant token.")
if "system_message" not in chat_template_config:
logger.warning("No system message is provided, will not use system message.")
if "chat_template" not in chat_template_config:

View File

@ -248,9 +248,9 @@ class StatefulDistributedSampler(DistributedSampler):
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
use_tp: Optional[bool] = False,
tp_size: int = 1,
) -> None:
if not use_tp:
if tp_size <= 1:
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
@ -261,14 +261,16 @@ class StatefulDistributedSampler(DistributedSampler):
)
else:
# adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62
# TODO: support tp_group>1. will fix it later
num_replicas = 1
if rank is None:
rank = dist.get_rank()
world_size = dist.get_world_size()
dp_size = world_size // tp_size # data parallel size
dp_rank = rank // tp_size # data parallel rank
if rank < 0:
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]")
self.dataset = dataset
self.num_replicas = num_replicas
self.dp_rank = dp_rank
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
@ -287,10 +289,10 @@ class StatefulDistributedSampler(DistributedSampler):
self.shuffle = shuffle
self.seed = seed
self.start_index = 0
self.use_tp = use_tp
self.tp_size = tp_size
def __iter__(self) -> Iterator:
if self.use_tp:
if self.tp_size > 1:
# TODO Add support for tp_group not equal to 1
pass
# adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96
@ -316,10 +318,9 @@ class StatefulDistributedSampler(DistributedSampler):
# subsample
indices = indices[
: self.total_size : self.num_replicas
self.dp_rank : self.dp_rank + self.total_size : self.num_replicas
]  # num_replicas equals dp_size; every rank in the same tp group reads the shard of its dp rank
assert len(indices) == self.num_samples
return iter(indices)
else:
@ -345,7 +346,7 @@ def setup_distributed_dataloader(
num_workers: int = 0,
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
process_group: Optional[ProcessGroup] = None,
use_tp: Optional[bool] = False,
tp_size: Optional[int] = 1,
**kwargs,
) -> DataLoader:
"""
@ -353,14 +354,16 @@ def setup_distributed_dataloader(
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
# world_size = tp_size * pp_size
assert process_group.size() % tp_size == 0, f"process_group.size()={process_group.size()} must be divisible by tp_size={tp_size}"
sampler = StatefulDistributedSampler(
dataset=dataset,
num_replicas=process_group.size() if not use_tp else 1,
num_replicas=process_group.size() // tp_size,
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
use_tp=use_tp,
tp_size=tp_size,
)
# Deterministic dataloader
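A small sketch of the rank-to-data-parallel-rank mapping the sampler above relies on (it assumes tensor-parallel ranks are contiguous, which is what `dp_rank = rank // tp_size` implies):

```python
# Each tensor-parallel group must see the same batch, so the dataset is
# sharded over dp_size = world_size // tp_size replicas, not world_size.
world_size, tp_size = 8, 2
dp_size = world_size // tp_size  # 4 data-parallel shards

for rank in range(world_size):
    dp_rank = rank // tp_size
    print(f"rank={rank} -> dp_rank={dp_rank}")
# ranks 0,1 share dp_rank 0; ranks 2,3 share dp_rank 1; and so on,
# so both members of a tp group iterate over identical samples.
```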

View File

@ -55,6 +55,8 @@ def supervised_tokenize_sft(
for mess in messages:
from_str = mess["from"]
if from_str is None:
raise ValueError(f"Message has no 'from' field: {mess}")
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
@ -95,17 +97,26 @@ def supervised_tokenize_sft(
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
conversation_template.end_of_assistant)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
labels = [ignore_index] * len(tokenized)
label_decode = []
for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start : end + 1] = tokenized[start : end + 1]
label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
labels[start : end] = tokenized[start : end]
# truncate the sequence at the last token that requires loss calculation
to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
if labels[i] == ignore_index:
to_truncate_len += 1
else:
break
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
@ -123,6 +134,20 @@ def supervised_tokenize_sft(
# For some model without bos/eos may raise the following errors
try:
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
label_decode = []
for i in range(len(labels)):
if labels[i] == ignore_index:
if start != end:
label_decode.append(tokenizer.decode(labels[start+1:i], skip_special_tokens=False))
start = i
end = i
else:
end = i
if i == len(labels) - 1:
label_decode.append(tokenizer.decode(labels[start+1:], skip_special_tokens=False))
except TypeError as e:
raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
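The truncation loop added above can be checked in isolation; a self-contained sketch with made-up token ids:

```python
# Toy demonstration of the trailing-truncation step: everything after the
# last token that requires loss (label != ignore_index) is dropped.
ignore_index = -100
tokenized = [1, 11, 12, 13, 2, 99, 99]  # made-up token ids
labels = [ignore_index, ignore_index, 12, 13, 2, ignore_index, ignore_index]

to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
    if labels[i] == ignore_index:
        to_truncate_len += 1
    else:
        break
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]

assert tokenized == [1, 11, 12, 13, 2] and labels[-1] != ignore_index
```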
@ -191,7 +216,9 @@ def tokenize_prompt_dataset(
# Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: target_turn], prompt,
conversation_template.end_of_assistant)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
@ -219,7 +246,8 @@ def apply_rlhf_data_format(
):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
template.end_of_assistant)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
loss_mask = [0] * len(tokenized)
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
@ -232,8 +260,8 @@ def apply_rlhf_data_format(
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1])
label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
loss_mask[start : end] = [1] * len(loss_mask[start : end])
label_decode.append(tokenizer.decode(tokenized[start : end], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized

View File

@ -113,20 +113,23 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
return input_ids, loss_starts, loss_ends
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str):
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str, end_of_assistant: str):
# Separate the templated prompt into chunks by human/assistant lines, preparing data for tokenize_and_concatenate
start_idx = 0
chunks = []
require_loss = []
for line in messages:
content_length = len(line["content"])
first_occur = prompt.find(line["content"], start_idx)
if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length:]:
content_length = prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + len(line["content"])])
chunks.append(prompt[first_occur : first_occur + content_length])
else:
chunks.append(prompt[start_idx : first_occur - 1])
chunks.append(prompt[first_occur - 1 : first_occur + len(line["content"])])
start_idx = first_occur + len(line["content"])
chunks.append(prompt[first_occur - 1 : first_occur + content_length])
start_idx = first_occur + content_length
if line["role"].lower() == "assistant":
require_loss.append(False)
require_loss.append(True)
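A toy, self-contained run of the new end_of_assistant handling (ChatML-style strings chosen only for illustration): the assistant chunk is extended through the stop token so the token itself is included in the loss span.

```python
prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n"
content = "Hello!"
end_of_assistant = "<|im_end|>"

first_occur = prompt.find(content)
content_length = len(content)
# Extend the assistant chunk through end_of_assistant, as in the hunk above.
if end_of_assistant in prompt[first_occur + content_length :]:
    content_length = (
        prompt.find(end_of_assistant, first_occur + content_length)
        + len(end_of_assistant)
        - first_occur
    )
print(prompt[first_occur : first_occur + content_length])  # Hello!<|im_end|>
```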

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
7
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -3,5 +3,6 @@
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
],
"end_of_assistant": "<|endoftext|>"
}

View File

@ -3,5 +3,6 @@
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
],
"end_of_assistant": "<|user|>"
}

View File

@ -1,7 +0,0 @@
{
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}

View File

@ -2,6 +2,7 @@
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
null
]
}
2
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -3,5 +3,6 @@
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
],
"end_of_assistant": "</s>"
}

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
100001
],
"end_of_assistant": "<end▁of▁sentence>"
}

View File

@ -3,5 +3,6 @@
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}
],
"end_of_assistant": "</s>"
}

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
50256
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -1,7 +0,0 @@
{
"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
2
]
}

View File

@ -338,7 +338,18 @@ In this code we provide a flexible way for users to set the conversation templat
{
"chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
"system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added,
"stop_ids": (Optional), A list of string indicating the end of assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically,
"end_of_assistant": The token(s) in string that denotes the end of assistance's response. For example, in the ChatGLM2 prompt format,
```
<|im_start|>system
system messages
<|im_end|>
<|im_start|>user
How far is the moon? <|im_end|>
<|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>...
```
the end_of_assistant token is "<|im_end|>",
"stop_ids": (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicating the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically
}
```
On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use a custom chat template) and the "system_message" (if you want to use a custom system message).
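As a concrete illustration, a config following this schema could be written out as below (the values mirror the ChatML-style templates added in this commit; the file name is arbitrary):

```python
import json

conversation_template = {
    "chat_template": None,  # None falls back to the tokenizer's own template
    "system_message": "A chat between a curious human and an artificial "
    "intelligence assistant. The assistant gives helpful, detailed, and "
    "polite answers to the human's questions.\n\n",
    "stop_ids": [2],  # token id(s) that end generation during PPO rollout
    "end_of_assistant": "<|im_end|>",  # must always be set explicitly
}
with open("my_conversation_template.json", "w", encoding="utf8") as f:
    json.dump(conversation_template, f, ensure_ascii=False, indent=2)
```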

View File

@ -226,7 +226,7 @@ def main():
"max_length": args.max_length,
},
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.filter(

View File

@ -1,13 +1,22 @@
SAVE_DIR=""
SAVE_DIR="/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf"
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
# python prepare_dataset.py --type sft \
# --data_input_dirs /home/yeanbang/data/experiment/dataset/sft_data/test/sft-data \
# --conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/THUDM_chatglm3-6b.json \
# --tokenizer_dir "/mnt/jfs-hdd/home/data/models/ChatGlm-6B" \
# --data_cache_dir $SAVE_DIR/cache \
# --data_jsonl_output_dir $SAVE_DIR/jsonl \
# --data_arrow_output_dir $SAVE_DIR/arrow \
python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_input_dirs /mnt/jfs-hdd/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train \
--conversation_template_config /home/yeanbang/data/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
--tokenizer_dir "/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \

View File

@ -1 +1 @@
10.20.1.82
172.27.183.199

View File

@ -14,13 +14,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
import inspect
import sys
import torch.distributed as dist
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def train(args):
print(colossalai.__version__, inspect.getfile(colossalai))
print(sys.executable)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
@ -35,6 +42,38 @@ def train(args):
# ==============================
# Initialize Booster
# ==============================
init_ctx = nullcontext()
with init_ctx:
model = AutoModelForCausalLM.from_pretrained(args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True)
# check if the hybrid parallel plugin is compatible with the model
# try:
# from colossalai.shardformer.policies.auto_policy import get_autopolicy
# policy = get_autopolicy(model)
# if policy is not None:
# if args.plugin in ['zero2', 'zero2_cpu']:
# # if compatible, set the plugin to hybrid, which use colo-attention
# args.plugin = 'hybrid'
# args.zero_stage = 2
# if args.plugin == 'zero2_cpu':
# args.zero_cpu_offload = True
# else:
# args.zero_cpu_offload = False
# logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
# except NotImplementedError:
# logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
# if args.use_flash_attn:
# del model
# model = AutoModelForCausalLM.from_pretrained(
# args.pretrain,
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
# attn_implementation="flash_attention_2",
# trust_remote_code=True
# )
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for
@ -47,7 +86,8 @@ def train(args):
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@ -55,6 +95,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@ -71,11 +112,16 @@ def train(args):
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
elif args.plugin == "hybrid":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
@ -93,20 +139,7 @@ def train(args):
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
# lora layers are not supported by gradient checkpointing
model.gradient_checkpointing_enable()
@ -131,6 +164,7 @@ def train(args):
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "right"
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
@ -156,8 +190,13 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
# print(len(train_dataloader))
# for batch in train_dataloader:
# print(dist.get_rank(), tokenizer.batch_decode(batch["input_ids"]))
# break
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
@ -255,7 +294,7 @@ def train(args):
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
# booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -270,13 +309,18 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
@ -287,7 +331,7 @@ if __name__ == "__main__":
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",

View File

@ -17,22 +17,22 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
# export CUDA_VISIBLE_DEVICES=4,5,6
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
PARENT_SAVE_DIR="/home/yeanbang/data/experiment/output/model" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/logs/tensorboard" # Path to a folder to save logs
PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/logs/config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local tokenizer path
declare -a dataset=(
YOUR/SFT/DATA/DIR/arrow/part-00000
YOUR/SFT/DATA/DIR/arrow/part-00001
YOUR/SFT/DATA/DIR/arrow/part-00002
YOUR/SFT/DATA/DIR/arrow/part-00003
YOUR/SFT/DATA/DIR/arrow/part-00004
YOUR/SFT/DATA/DIR/arrow/part-00005
YOUR/SFT/DATA/DIR/arrow/part-00006
YOUR/SFT/DATA/DIR/arrow/part-00007
YOUR/SFT/DATA/DIR/arrow/part-00008
YOUR/SFT/DATA/DIR/arrow/part-00009
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00000
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00001
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00002
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00003
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00004
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00005
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00006
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00007
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00008
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@ -40,6 +40,8 @@ FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
echo $(which colossalai)
echo $(which python)
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
@ -50,10 +52,14 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
--config_file $CONFIG_FILE \
--lora_rank 0 \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
--accumulation_steps 1 \
--lr 2e-5 \
--max_len 2048 \
--tp 1 \
--pp 1 \
--zero_stage 2 \
--batch_size 4 \
--max_epochs 3 \
--accumulation_steps 4 \
--lr 5e-5 \
--max_len 400 \
--grad_checkpoint \
--use_wandb
--use_wandb \
--use_flash_attn
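Following the batch-size comment in the script above, the effective optimizer batch size under these flags works out as follows (a rough sketch; it assumes tp/pp/sp ranks all consume the same data, which is what the sampler change in this PR provides):

```python
nodes, nproc_per_node = 1, 4  # one host, 4 GPUs (matches the launch line)
tp, pp, sp = 1, 1, 1          # parallel degrees from the flags above
micro_batch, accumulation_steps = 4, 4

world_size = nodes * nproc_per_node
dp_size = world_size // (tp * pp * sp)  # data-parallel replicas
effective_batch = dp_size * micro_batch * accumulation_steps
print(effective_batch)  # 64 samples per optimizer step
```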

View File

@ -1,9 +1,8 @@
transformers==4.34.1
huggingface_hub==0.17.3
transformers>=4.37.0
tqdm
datasets
loralib
colossalai>=0.3.6
colossalai>=0.3.7
torch>=1.12.1
langchain
tokenizers

View File

@ -0,0 +1,38 @@
from coati.dataset import setup_conversation_template
from coati.dataset.conversation import Conversation
from coati.dataset.tokenization_utils import supervised_tokenize_sft
from transformers import AutoTokenizer
import json
import os
model_data_mapping = {
'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
'Qwen/Qwen-7B-Chat': 'Qwen_Qwen-7B-Chat.json',
'01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
'01-ai/Yi-34B': '01-ai_Yi-34B.json',
'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
'microsoft/phi-2': 'microsoft_phi-2.json',
'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
}
chat_template_config_path = '../config/conversation_template'
def test_tokenization_sft():
for model in model_data_mapping:
print(f"#############{model}#############")
conversation_template_config = os.path.join(chat_template_config_path, model_data_mapping[model])
messages = [{"from": "human", "content": "What are the three primary colors?"},
{"from": "assistant", "content": "The three primary colors are red, blue, and yellow."},
{"from": "human", "content": "解释个人电脑和服务器之间的区别。"},
{"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}]
chat_template_config = json.load(open(conversation_template_config, "r", encoding="utf8"))
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=chat_template_config, save_path=conversation_template_config
)
output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
with open(f"./test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
assert json.dumps(json.load(f)) == json.dumps(output), f"model: {model} failed"

View File

@ -0,0 +1,585 @@
{
"input_ids": [
1,
59603,
9334,
1397,
562,
13310,
2756,
597,
663,
15874,
10357,
14135,
98,
707,
14135,
3641,
6901,
97,
7283,
97,
597,
31081,
8476,
592,
567,
2756,
59610,
59575,
3275,
98,
144,
144,
6,
3903,
144,
5697,
678,
567,
1604,
5789,
7590,
100,
7,
144,
6,
765,
13611,
144,
1263,
1604,
5789,
7590,
678,
2894,
97,
5083,
97,
597,
10744,
98,
7,
144,
6,
3903,
144,
7714,
2897,
11491,
59652,
17504,
7125,
13189,
102,
7,
144,
6,
765,
13611,
144,
2897,
11491,
59652,
17504,
59626,
12295,
2618,
24768,
12780,
2644,
101,
6774,
10495,
13189,
5505,
23337,
105,
17713,
9217,
59652,
9176,
102,
59568,
2897,
11491,
101,
60604,
59867,
60084,
60160,
101,
29874,
2897,
2253,
59732,
23806,
12780,
102,
6774,
8224,
4983,
7304,
5973,
105,
11152,
44353,
101,
1229,
7252,
33282,
32371,
59652,
5186,
102,
2897,
11491,
59599,
17713,
9217,
39485,
4245,
3600,
9217,
59670,
23806,
101,
3438,
6985,
2749,
2897,
4195,
1540,
23525,
102,
59568,
59732,
17504,
15703,
6748,
8262,
4576,
17107,
59732,
23806,
12780,
2644,
101,
6774,
8224,
4983,
59648,
4576,
2479,
4207,
4014,
1893,
101,
59738,
8276,
105,
5601,
24111,
59652,
6280,
18192,
59748,
102,
17504,
8224,
1867,
57898,
59599,
17713,
9217,
101,
4502,
1229,
17456,
59719,
38298,
59652,
59509,
7252,
102,
3574,
17504,
1867,
3392,
8262,
28830,
15595,
101,
6774,
8224,
18802,
59691,
60326,
25829,
105,
59647,
22981,
25145,
24724,
22981,
43452,
11553,
60101,
101,
59659,
3206,
13890,
7252,
7042,
59652,
26815,
102,
59568,
20237,
101,
2897,
11491,
59652,
17504,
7125,
2461,
13189,
5505,
30281,
23337,
105,
17713,
9217,
59652,
9176,
102,
2897,
11491,
4983,
2897,
2253,
101,
59732,
17504,
4983,
3392,
8262,
28830,
15595,
102,
1893,
12507,
17713,
9217,
8224,
59806,
2897,
11491,
16549,
101,
59659,
5731,
13890,
9176,
59652,
26815,
102,
7
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
1263,
1604,
5789,
7590,
678,
2894,
97,
5083,
97,
597,
10744,
98,
7,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
2897,
11491,
59652,
17504,
59626,
12295,
2618,
24768,
12780,
2644,
101,
6774,
10495,
13189,
5505,
23337,
105,
17713,
9217,
59652,
9176,
102,
59568,
2897,
11491,
101,
60604,
59867,
60084,
60160,
101,
29874,
2897,
2253,
59732,
23806,
12780,
102,
6774,
8224,
4983,
7304,
5973,
105,
11152,
44353,
101,
1229,
7252,
33282,
32371,
59652,
5186,
102,
2897,
11491,
59599,
17713,
9217,
39485,
4245,
3600,
9217,
59670,
23806,
101,
3438,
6985,
2749,
2897,
4195,
1540,
23525,
102,
59568,
59732,
17504,
15703,
6748,
8262,
4576,
17107,
59732,
23806,
12780,
2644,
101,
6774,
8224,
4983,
59648,
4576,
2479,
4207,
4014,
1893,
101,
59738,
8276,
105,
5601,
24111,
59652,
6280,
18192,
59748,
102,
17504,
8224,
1867,
57898,
59599,
17713,
9217,
101,
4502,
1229,
17456,
59719,
38298,
59652,
59509,
7252,
102,
3574,
17504,
1867,
3392,
8262,
28830,
15595,
101,
6774,
8224,
18802,
59691,
60326,
25829,
105,
59647,
22981,
25145,
24724,
22981,
43452,
11553,
60101,
101,
59659,
3206,
13890,
7252,
7042,
59652,
26815,
102,
59568,
20237,
101,
2897,
11491,
59652,
17504,
7125,
2461,
13189,
5505,
30281,
23337,
105,
17713,
9217,
59652,
9176,
102,
2897,
11491,
4983,
2897,
2253,
101,
59732,
17504,
4983,
3392,
8262,
28830,
15595,
102,
1893,
12507,
17713,
9217,
8224,
59806,
2897,
11491,
16549,
101,
59659,
5731,
13890,
9176,
59652,
26815,
102,
7
],
"inputs_decode": "<|startoftext|> A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_start|>user\nWhat are the three primary colors?<|im_end|> \n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|> \n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|> \n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>",
"labels_decode": [
" The three primary colors are red, blue, and yellow.<|im_end|>",
" 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>"
],
"seq_length": 286,
"seq_category": "None"
}

View File

@ -0,0 +1,607 @@
{
"input_ids": [
1,
59653,
34448,
59651,
7488,
55701,
4422,
144,
59603,
9334,
1397,
562,
13310,
2756,
597,
663,
15874,
10357,
14135,
98,
707,
14135,
3641,
6901,
97,
7283,
97,
597,
31081,
8476,
592,
567,
2756,
59610,
59575,
3275,
98,
144,
144,
144,
59666,
1359,
55701,
4422,
144,
144,
5697,
678,
567,
1604,
5789,
7590,
100,
1273,
59598,
34448,
59651,
707,
1604,
5789,
7590,
678,
2894,
97,
5083,
97,
597,
10744,
98,
59568,
2,
1,
59653,
34448,
59651,
59568,
7714,
2897,
11491,
59652,
17504,
7125,
13189,
102,
1273,
59598,
34448,
59651,
59568,
2897,
11491,
59652,
17504,
59626,
12295,
2618,
24768,
12780,
2644,
101,
6774,
10495,
13189,
5505,
23337,
105,
17713,
9217,
59652,
9176,
102,
59568,
2897,
11491,
101,
60604,
59867,
60084,
60160,
101,
29874,
2897,
2253,
59732,
23806,
12780,
102,
6774,
8224,
4983,
7304,
5973,
105,
11152,
44353,
101,
1229,
7252,
33282,
32371,
59652,
5186,
102,
2897,
11491,
59599,
17713,
9217,
39485,
4245,
3600,
9217,
59670,
23806,
101,
3438,
6985,
2749,
2897,
4195,
1540,
23525,
102,
59568,
59732,
17504,
15703,
6748,
8262,
4576,
17107,
59732,
23806,
12780,
2644,
101,
6774,
8224,
4983,
59648,
4576,
2479,
4207,
4014,
1893,
101,
59738,
8276,
105,
5601,
24111,
59652,
6280,
18192,
59748,
102,
17504,
8224,
1867,
57898,
59599,
17713,
9217,
101,
4502,
1229,
17456,
59719,
38298,
59652,
59509,
7252,
102,
3574,
17504,
1867,
3392,
8262,
28830,
15595,
101,
6774,
8224,
18802,
59691,
60326,
25829,
105,
59647,
22981,
25145,
24724,
22981,
43452,
11553,
60101,
101,
59659,
3206,
13890,
7252,
7042,
59652,
26815,
102,
59568,
20237,
101,
2897,
11491,
59652,
17504,
7125,
2461,
13189,
5505,
30281,
23337,
105,
17713,
9217,
59652,
9176,
102,
2897,
11491,
4983,
2897,
2253,
101,
59732,
17504,
4983,
3392,
8262,
28830,
15595,
102,
1893,
12507,
17713,
9217,
8224,
59806,
2897,
11491,
16549,
101,
59659,
5731,
13890,
9176,
59652,
26815,
102,
59568,
2
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
707,
1604,
5789,
7590,
678,
2894,
97,
5083,
97,
597,
10744,
98,
59568,
2,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
59568,
2897,
11491,
59652,
17504,
59626,
12295,
2618,
24768,
12780,
2644,
101,
6774,
10495,
13189,
5505,
23337,
105,
17713,
9217,
59652,
9176,
102,
59568,
2897,
11491,
101,
60604,
59867,
60084,
60160,
101,
29874,
2897,
2253,
59732,
23806,
12780,
102,
6774,
8224,
4983,
7304,
5973,
105,
11152,
44353,
101,
1229,
7252,
33282,
32371,
59652,
5186,
102,
2897,
11491,
59599,
17713,
9217,
39485,
4245,
3600,
9217,
59670,
23806,
101,
3438,
6985,
2749,
2897,
4195,
1540,
23525,
102,
59568,
59732,
17504,
15703,
6748,
8262,
4576,
17107,
59732,
23806,
12780,
2644,
101,
6774,
8224,
4983,
59648,
4576,
2479,
4207,
4014,
1893,
101,
59738,
8276,
105,
5601,
24111,
59652,
6280,
18192,
59748,
102,
17504,
8224,
1867,
57898,
59599,
17713,
9217,
101,
4502,
1229,
17456,
59719,
38298,
59652,
59509,
7252,
102,
3574,
17504,
1867,
3392,
8262,
28830,
15595,
101,
6774,
8224,
18802,
59691,
60326,
25829,
105,
59647,
22981,
25145,
24724,
22981,
43452,
11553,
60101,
101,
59659,
3206,
13890,
7252,
7042,
59652,
26815,
102,
59568,
20237,
101,
2897,
11491,
59652,
17504,
7125,
2461,
13189,
5505,
30281,
23337,
105,
17713,
9217,
59652,
9176,
102,
2897,
11491,
4983,
2897,
2253,
101,
59732,
17504,
4983,
3392,
8262,
28830,
15595,
102,
1893,
12507,
17713,
9217,
8224,
59806,
2897,
11491,
16549,
101,
59659,
5731,
13890,
9176,
59652,
26815,
102,
59568,
2
],
"inputs_decode": "<|startoftext|> [INST] <<SYS>>\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n\n<</SYS>>\n\nWhat are the three primary colors? [/INST] The three primary colors are red, blue, and yellow. <|endoftext|><|startoftext|> [INST] 解释个人电脑和服务器之间的区别。 [/INST] 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。 <|endoftext|>",
"labels_decode": [
" The three primary colors are red, blue, and yellow. <|endoftext|>",
" 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。 <|endoftext|>"
],
"seq_length": 297,
"seq_category": "None"
}

View File

@ -0,0 +1,603 @@
{
"input_ids": [
151644,
8948,
198,
32,
6236,
1948,
264,
22208,
3738,
323,
458,
20443,
11229,
17847,
13,
576,
17847,
6696,
10950,
11,
11682,
11,
323,
47787,
11253,
311,
279,
3738,
594,
4755,
382,
151645,
198,
151644,
872,
198,
3838,
525,
279,
2326,
6028,
7987,
30,
151645,
198,
151644,
77091,
198,
785,
2326,
6028,
7987,
525,
2518,
11,
6303,
11,
323,
13753,
13,
151645,
198,
151644,
872,
198,
104136,
99605,
104145,
33108,
89047,
104186,
102665,
1773,
151645,
198,
151644,
77091,
198,
99605,
104145,
33108,
89047,
20412,
101441,
99604,
109963,
104564,
72448,
3837,
104017,
104396,
102665,
101321,
105795,
5373,
105433,
85767,
33108,
102111,
1773,
220,
99605,
104145,
3837,
99846,
13072,
90663,
64559,
3837,
20412,
17714,
99605,
37029,
68536,
70500,
9370,
104564,
1773,
104017,
102119,
100751,
101254,
104066,
5373,
100415,
33108,
100134,
3837,
73670,
104001,
100646,
99200,
100535,
113384,
33108,
99329,
1773,
99605,
104145,
9370,
105433,
85767,
111071,
101892,
100142,
85767,
36407,
70500,
9370,
3837,
100632,
104047,
100345,
99605,
100354,
71817,
104790,
1773,
8908,
222,
234,
89047,
104802,
101929,
100722,
20002,
104378,
68536,
70500,
9370,
104564,
72448,
3837,
104017,
102119,
100751,
17714,
110782,
100646,
71356,
47874,
3837,
29524,
100010,
5373,
116617,
33108,
26898,
107468,
49567,
1773,
89047,
102119,
85106,
113313,
9370,
105433,
85767,
90395,
100136,
73670,
105581,
44636,
118878,
33108,
102612,
9370,
104001,
1773,
101887,
89047,
85106,
100143,
100722,
107494,
104925,
3837,
104017,
102119,
102578,
42140,
71137,
111249,
5373,
26288,
106656,
111390,
33108,
26288,
106656,
114897,
102474,
31548,
3837,
23031,
100627,
105743,
104001,
101149,
33108,
108239,
1773,
90476,
119,
53930,
3837,
99605,
104145,
33108,
89047,
104186,
99558,
102665,
101321,
104017,
9370,
105795,
5373,
105433,
85767,
33108,
102111,
1773,
99605,
104145,
100751,
99605,
37029,
3837,
68536,
89047,
100751,
100143,
100722,
107494,
104925,
1773,
89047,
9370,
105433,
85767,
102119,
56006,
99605,
104145,
105377,
3837,
23031,
101907,
105743,
102111,
33108,
108239,
1773,
151645
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
785,
2326,
6028,
7987,
525,
2518,
11,
6303,
11,
323,
13753,
13,
151645,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
99605,
104145,
33108,
89047,
20412,
101441,
99604,
109963,
104564,
72448,
3837,
104017,
104396,
102665,
101321,
105795,
5373,
105433,
85767,
33108,
102111,
1773,
220,
99605,
104145,
3837,
99846,
13072,
90663,
64559,
3837,
20412,
17714,
99605,
37029,
68536,
70500,
9370,
104564,
1773,
104017,
102119,
100751,
101254,
104066,
5373,
100415,
33108,
100134,
3837,
73670,
104001,
100646,
99200,
100535,
113384,
33108,
99329,
1773,
99605,
104145,
9370,
105433,
85767,
111071,
101892,
100142,
85767,
36407,
70500,
9370,
3837,
100632,
104047,
100345,
99605,
100354,
71817,
104790,
1773,
8908,
222,
234,
89047,
104802,
101929,
100722,
20002,
104378,
68536,
70500,
9370,
104564,
72448,
3837,
104017,
102119,
100751,
17714,
110782,
100646,
71356,
47874,
3837,
29524,
100010,
5373,
116617,
33108,
26898,
107468,
49567,
1773,
89047,
102119,
85106,
113313,
9370,
105433,
85767,
90395,
100136,
73670,
105581,
44636,
118878,
33108,
102612,
9370,
104001,
1773,
101887,
89047,
85106,
100143,
100722,
107494,
104925,
3837,
104017,
102119,
102578,
42140,
71137,
111249,
5373,
26288,
106656,
111390,
33108,
26288,
106656,
114897,
102474,
31548,
3837,
23031,
100627,
105743,
104001,
101149,
33108,
108239,
1773,
90476,
119,
53930,
3837,
99605,
104145,
33108,
89047,
104186,
99558,
102665,
101321,
104017,
9370,
105795,
5373,
105433,
85767,
33108,
102111,
1773,
99605,
104145,
100751,
99605,
37029,
3837,
68536,
89047,
100751,
100143,
100722,
107494,
104925,
1773,
89047,
9370,
105433,
85767,
102119,
56006,
99605,
104145,
105377,
3837,
23031,
101907,
105743,
102111,
33108,
108239,
1773,
151645
],
"inputs_decode": "<|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\nWhat are the three primary colors?<|im_end|>\n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|>\n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|>\n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>",
"labels_decode": [
"The three primary colors are red, blue, and yellow.<|im_end|>",
"个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>"
],
"seq_length": 295,
"seq_category": "None"
}

View File

@ -0,0 +1,715 @@
{
"input_ids": [
906,
31007,
326,
30962,
6631,
31007,
30994,
13361,
13,
30938,
8911,
1074,
260,
10976,
1869,
293,
284,
12060,
7151,
9319,
30930,
353,
9319,
3510,
6483,
30932,
6374,
30932,
293,
25712,
7115,
289,
267,
1869,
30953,
30917,
2554,
30930,
13,
13,
31002,
31007,
326,
30962,
437,
31007,
30994,
13,
31002,
31007,
326,
30962,
6631,
31007,
30994,
4865,
13,
1266,
383,
267,
1194,
4390,
7129,
30987,
906,
31007,
326,
30962,
437,
31007,
30994,
13,
31002,
31007,
326,
30962,
6631,
31007,
30994,
530,
18971,
13,
353,
1194,
4390,
7129,
383,
2603,
30932,
4610,
30932,
293,
8362,
30930,
31002,
31007,
326,
30962,
437,
31007,
30994,
30910,
13,
31002,
31007,
326,
30962,
6631,
31007,
30994,
4865,
13,
30910,
32929,
31745,
34032,
37154,
55083,
32695,
34213,
31155,
906,
31007,
326,
30962,
437,
31007,
30994,
13,
31002,
31007,
326,
30962,
6631,
31007,
30994,
530,
18971,
13,
30910,
31745,
34032,
37154,
55083,
54532,
33701,
31723,
39141,
33606,
31739,
31123,
32542,
33248,
34213,
33260,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
30910,
31745,
34032,
31123,
55502,
54653,
54872,
54923,
31123,
38391,
31745,
31695,
54617,
37538,
33606,
31155,
32542,
32955,
32419,
32432,
32249,
31201,
32645,
44505,
31123,
31628,
32557,
41054,
43569,
54542,
32033,
31155,
31745,
34032,
54530,
36073,
33339,
47907,
32001,
31858,
33339,
54556,
37538,
31123,
31925,
32591,
31793,
31745,
31987,
31636,
36058,
31155,
43833,
38107,
35231,
32454,
32325,
32053,
35299,
54617,
37538,
33606,
31739,
31123,
32542,
32955,
32419,
54541,
49193,
31986,
31863,
31645,
31123,
54627,
32438,
31201,
48747,
54542,
32410,
38520,
54609,
31155,
38107,
32955,
31665,
46294,
54530,
36073,
33339,
31123,
32187,
31628,
35757,
54589,
54376,
54542,
53144,
32557,
31155,
31949,
38107,
31665,
31818,
32325,
39195,
33967,
31123,
32542,
32955,
34881,
54573,
55110,
39727,
31201,
54539,
36608,
40538,
37251,
36608,
52898,
34861,
55083,
31123,
54548,
31803,
34650,
32557,
32615,
54542,
39941,
31155,
30910,
37169,
31123,
31745,
34032,
37154,
55083,
32695,
31703,
34213,
33260,
41215,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
31745,
34032,
32419,
31745,
31695,
31123,
54617,
38107,
32419,
31818,
32325,
39195,
33967,
31155,
31645,
35949,
36073,
33339,
32955,
54703,
31745,
34032,
34732,
31123,
54548,
32444,
34650,
33116,
54542,
39941,
31155,
31002,
31007,
326,
30962,
437,
31007,
30994,
2
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
353,
1194,
4390,
7129,
383,
2603,
30932,
4610,
30932,
293,
8362,
30930,
31002,
31007,
326,
30962,
437,
31007,
30994,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
30910,
31745,
34032,
37154,
55083,
54532,
33701,
31723,
39141,
33606,
31739,
31123,
32542,
33248,
34213,
33260,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
30910,
31745,
34032,
31123,
55502,
54653,
54872,
54923,
31123,
38391,
31745,
31695,
54617,
37538,
33606,
31155,
32542,
32955,
32419,
32432,
32249,
31201,
32645,
44505,
31123,
31628,
32557,
41054,
43569,
54542,
32033,
31155,
31745,
34032,
54530,
36073,
33339,
47907,
32001,
31858,
33339,
54556,
37538,
31123,
31925,
32591,
31793,
31745,
31987,
31636,
36058,
31155,
43833,
38107,
35231,
32454,
32325,
32053,
35299,
54617,
37538,
33606,
31739,
31123,
32542,
32955,
32419,
54541,
49193,
31986,
31863,
31645,
31123,
54627,
32438,
31201,
48747,
54542,
32410,
38520,
54609,
31155,
38107,
32955,
31665,
46294,
54530,
36073,
33339,
31123,
32187,
31628,
35757,
54589,
54376,
54542,
53144,
32557,
31155,
31949,
38107,
31665,
31818,
32325,
39195,
33967,
31123,
32542,
32955,
34881,
54573,
55110,
39727,
31201,
54539,
36608,
40538,
37251,
36608,
52898,
34861,
55083,
31123,
54548,
31803,
34650,
32557,
32615,
54542,
39941,
31155,
30910,
37169,
31123,
31745,
34032,
37154,
55083,
32695,
31703,
34213,
33260,
41215,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
31745,
34032,
32419,
31745,
31695,
31123,
54617,
38107,
32419,
31818,
32325,
39195,
33967,
31155,
31645,
35949,
36073,
33339,
32955,
54703,
31745,
34032,
34732,
31123,
54548,
32444,
34650,
33116,
54542,
39941,
31155,
31002,
31007,
326,
30962,
437,
31007,
30994,
2
],
"inputs_decode": "<|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\n What are the three primary colors? <|im_end|>\n<|im_start|>assistant\n The three primary colors are red, blue, and yellow.<|im_end|> \n<|im_start|>user\n 解释个人电脑和服务器之间的区别。 <|im_end|>\n<|im_start|>assistant\n 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>",
"labels_decode": [
"The three primary colors are red, blue, and yellow.<|im_end|>",
"个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|>"
],
"seq_length": 351,
"seq_category": "None"
}

View File

@ -0,0 +1,585 @@
{
"input_ids": [
64790,
64792,
906,
31007,
13361,
31007,
30994,
13,
316,
8911,
1074,
260,
10976,
1869,
293,
284,
12060,
7151,
9319,
30930,
353,
9319,
3510,
6483,
30932,
6374,
30932,
293,
25712,
7115,
289,
267,
1869,
30953,
30917,
2554,
30930,
13,
13,
64795,
30910,
13,
265,
5011,
383,
267,
1194,
4390,
7129,
30987,
64796,
30910,
13,
265,
1036,
1194,
4390,
7129,
383,
2603,
30932,
4610,
30932,
293,
8362,
30930,
64795,
30910,
13,
265,
32929,
31745,
34032,
37154,
55083,
32695,
34213,
31155,
64796,
30910,
13,
265,
31745,
34032,
37154,
55083,
54532,
33701,
31723,
39141,
33606,
31739,
31123,
32542,
33248,
34213,
33260,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
30910,
31745,
34032,
31123,
55502,
54653,
54872,
54923,
31123,
38391,
31745,
31695,
54617,
37538,
33606,
31155,
32542,
32955,
32419,
32432,
32249,
31201,
32645,
44505,
31123,
31628,
32557,
41054,
43569,
54542,
32033,
31155,
31745,
34032,
54530,
36073,
33339,
47907,
32001,
31858,
33339,
54556,
37538,
31123,
31925,
32591,
31793,
31745,
31987,
31636,
36058,
31155,
43833,
38107,
35231,
32454,
32325,
32053,
35299,
54617,
37538,
33606,
31739,
31123,
32542,
32955,
32419,
54541,
49193,
31986,
31863,
31645,
31123,
54627,
32438,
31201,
48747,
54542,
32410,
38520,
54609,
31155,
38107,
32955,
31665,
46294,
54530,
36073,
33339,
31123,
32187,
31628,
35757,
54589,
54376,
54542,
53144,
32557,
31155,
31949,
38107,
31665,
31818,
32325,
39195,
33967,
31123,
32542,
32955,
34881,
54573,
55110,
39727,
31201,
54539,
36608,
40538,
37251,
36608,
52898,
34861,
55083,
31123,
54548,
31803,
34650,
32557,
32615,
54542,
39941,
31155,
30910,
37169,
31123,
31745,
34032,
37154,
55083,
32695,
31703,
34213,
33260,
41215,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
31745,
34032,
32419,
31745,
31695,
31123,
54617,
38107,
32419,
31818,
32325,
39195,
33967,
31155,
31645,
35949,
36073,
33339,
32955,
54703,
31745,
34032,
34732,
31123,
54548,
32444,
34650,
33116,
54542,
39941,
31155,
2
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
265,
1036,
1194,
4390,
7129,
383,
2603,
30932,
4610,
30932,
293,
8362,
30930,
64795,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
265,
31745,
34032,
37154,
55083,
54532,
33701,
31723,
39141,
33606,
31739,
31123,
32542,
33248,
34213,
33260,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
30910,
31745,
34032,
31123,
55502,
54653,
54872,
54923,
31123,
38391,
31745,
31695,
54617,
37538,
33606,
31155,
32542,
32955,
32419,
32432,
32249,
31201,
32645,
44505,
31123,
31628,
32557,
41054,
43569,
54542,
32033,
31155,
31745,
34032,
54530,
36073,
33339,
47907,
32001,
31858,
33339,
54556,
37538,
31123,
31925,
32591,
31793,
31745,
31987,
31636,
36058,
31155,
43833,
38107,
35231,
32454,
32325,
32053,
35299,
54617,
37538,
33606,
31739,
31123,
32542,
32955,
32419,
54541,
49193,
31986,
31863,
31645,
31123,
54627,
32438,
31201,
48747,
54542,
32410,
38520,
54609,
31155,
38107,
32955,
31665,
46294,
54530,
36073,
33339,
31123,
32187,
31628,
35757,
54589,
54376,
54542,
53144,
32557,
31155,
31949,
38107,
31665,
31818,
32325,
39195,
33967,
31123,
32542,
32955,
34881,
54573,
55110,
39727,
31201,
54539,
36608,
40538,
37251,
36608,
52898,
34861,
55083,
31123,
54548,
31803,
34650,
32557,
32615,
54542,
39941,
31155,
30910,
37169,
31123,
31745,
34032,
37154,
55083,
32695,
31703,
34213,
33260,
41215,
35687,
31201,
36073,
33339,
54542,
33116,
31155,
31745,
34032,
32419,
31745,
31695,
31123,
54617,
38107,
32419,
31818,
32325,
39195,
33967,
31155,
31645,
35949,
36073,
33339,
32955,
54703,
31745,
34032,
34732,
31123,
54548,
32444,
34650,
33116,
54542,
39941,
31155,
2
],
"inputs_decode": "[gMASK] sop <|system|>\n A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n <|user|> \n What are the three primary colors? <|assistant|> \n The three primary colors are red, blue, and yellow. <|user|> \n 解释个人电脑和服务器之间的区别。 <|assistant|> \n 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。",
"labels_decode": [
" The three primary colors are red, blue, and yellow. <|user|>",
" 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"
],
"seq_length": 286,
"seq_category": "None"
}
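
The fixture above (and those that follow) pins down the supervised-fine-tuning masking convention: "labels" mirrors "input_ids" position by position, with every non-assistant token replaced by -100 so it is ignored by the loss, and "seq_length" records the sequence length. A minimal validation sketch is shown below; the file path and helper name are illustrative, not part of this diff.

import json

def check_fixture(path: str) -> None:
    # Load one tokenized-conversation fixture like the JSON object above.
    with open(path, encoding="utf-8") as f:
        fixture = json.load(f)
    input_ids, labels = fixture["input_ids"], fixture["labels"]
    # labels must align 1:1 with input_ids, and seq_length must match.
    assert len(input_ids) == len(labels)
    assert fixture["seq_length"] == len(input_ids)
    # A supervised position must echo its input id; everything else is -100.
    for token_id, label in zip(input_ids, labels):
        assert label == -100 or label == token_id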

View File

@ -0,0 +1,697 @@
{
"input_ids": [
1,
92655,
92647,
1418,
92484,
13387,
92647,
92574,
24399,
5,
92343,
15161,
2357,
1346,
23815,
3558,
1377,
1452,
19649,
11656,
14002,
72,
1481,
14002,
6474,
13629,
92323,
11144,
92323,
1377,
70217,
13514,
1375,
1352,
3558,
92404,
92319,
4852,
72,
5,
5,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
5,
92655,
92647,
1418,
92484,
13387,
92647,
92574,
8589,
5,
3167,
1484,
1352,
2397,
7721,
12654,
74,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
5,
92655,
92647,
1418,
92484,
13387,
92647,
92574,
1613,
10685,
5,
1524,
2397,
7721,
12654,
1484,
5005,
92323,
8488,
92323,
1377,
16149,
72,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
5,
92655,
92647,
1418,
92484,
13387,
92647,
92574,
8589,
5,
5987,
2076,
7283,
18431,
93107,
5332,
8592,
66,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
5,
92655,
92647,
1418,
92484,
13387,
92647,
92574,
1613,
10685,
5,
2076,
7283,
18431,
93107,
92347,
8653,
2381,
22152,
9403,
2274,
65,
20962,
2231,
8592,
6795,
13087,
69,
18512,
7587,
92385,
8222,
66,
92311,
2076,
7283,
65,
52905,
65,
23044,
2076,
1964,
92492,
17274,
9403,
66,
6085,
6984,
4349,
4161,
4588,
69,
5748,
30443,
65,
1583,
5155,
24491,
48901,
92385,
3872,
66,
2076,
7283,
92333,
18512,
7587,
28068,
2592,
2417,
7587,
92393,
17274,
65,
2776,
4463,
78704,
3173,
1697,
9501,
66,
100178,
31391,
10551,
4126,
4121,
3723,
13169,
92492,
17274,
9403,
2274,
65,
6085,
6984,
4349,
57325,
2099,
2936,
2012,
1760,
65,
92467,
4516,
69,
36308,
92385,
4733,
25092,
92457,
66,
31391,
6984,
1836,
49902,
92333,
18512,
7587,
65,
3741,
1583,
11871,
92434,
55317,
92385,
41203,
5155,
66,
3035,
31391,
1836,
2590,
4121,
30881,
17136,
65,
6085,
6984,
10040,
92418,
92888,
40793,
69,
92366,
18178,
32629,
18787,
18178,
66897,
14685,
93107,
65,
73411,
12452,
5155,
82905,
23304,
66,
92311,
14748,
65,
2076,
7283,
18431,
93107,
5332,
2231,
8592,
6795,
20962,
13087,
69,
18512,
7587,
92385,
8222,
66,
2076,
7283,
4349,
2076,
1964,
65,
92492,
31391,
4349,
2590,
4121,
30881,
17136,
66,
1760,
21050,
18512,
7587,
6984,
92596,
2076,
7283,
10680,
65,
57310,
12452,
8222,
92385,
23304,
66,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
2
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
1524,
2397,
7721,
12654,
1484,
5005,
92323,
8488,
92323,
1377,
16149,
72,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
2076,
7283,
18431,
93107,
92347,
8653,
2381,
22152,
9403,
2274,
65,
20962,
2231,
8592,
6795,
13087,
69,
18512,
7587,
92385,
8222,
66,
92311,
2076,
7283,
65,
52905,
65,
23044,
2076,
1964,
92492,
17274,
9403,
66,
6085,
6984,
4349,
4161,
4588,
69,
5748,
30443,
65,
1583,
5155,
24491,
48901,
92385,
3872,
66,
2076,
7283,
92333,
18512,
7587,
28068,
2592,
2417,
7587,
92393,
17274,
65,
2776,
4463,
78704,
3173,
1697,
9501,
66,
100178,
31391,
10551,
4126,
4121,
3723,
13169,
92492,
17274,
9403,
2274,
65,
6085,
6984,
4349,
57325,
2099,
2936,
2012,
1760,
65,
92467,
4516,
69,
36308,
92385,
4733,
25092,
92457,
66,
31391,
6984,
1836,
49902,
92333,
18512,
7587,
65,
3741,
1583,
11871,
92434,
55317,
92385,
41203,
5155,
66,
3035,
31391,
1836,
2590,
4121,
30881,
17136,
65,
6085,
6984,
10040,
92418,
92888,
40793,
69,
92366,
18178,
32629,
18787,
18178,
66897,
14685,
93107,
65,
73411,
12452,
5155,
82905,
23304,
66,
92311,
14748,
65,
2076,
7283,
18431,
93107,
5332,
2231,
8592,
6795,
20962,
13087,
69,
18512,
7587,
92385,
8222,
66,
2076,
7283,
4349,
2076,
1964,
65,
92492,
31391,
4349,
2590,
4121,
30881,
17136,
66,
1760,
21050,
18512,
7587,
6984,
92596,
2076,
7283,
10680,
65,
57310,
12452,
8222,
92385,
23304,
66,
92655,
92647,
1418,
92484,
1520,
92647,
92574,
2
],
"inputs_decode": "<s> <|im_start|>system\nA chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n<|im_end|>\n<|im_start|>user\nWhat are the three primary colors?<|im_end|>\n<|im_start|>assistant\nThe three primary colors are red, blue, and yellow.<|im_end|>\n<|im_start|>user\n解释个人电脑和服务器之间的区别。<|im_end|>\n<|im_start|>assistant\n个人电脑和服务器是两种不同类型的计算机系统它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|></s>",
"labels_decode": [
"The three primary colors are red, blue, and yellow.<|im_end|>",
" 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<|im_end|></s>"
],
"seq_length": 342,
"seq_category": "None"
}
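
The "labels_decode" list in each fixture is the per-turn decode of the supervised spans: contiguous runs of non -100 labels, one string per assistant reply. A hedged sketch of how those strings can be reproduced (the tokenizer handle is an assumption, not something this diff pins down):

from typing import List
from transformers import PreTrainedTokenizer

def decode_supervised_spans(
    input_ids: List[int], labels: List[int], tokenizer: PreTrainedTokenizer
) -> List[str]:
    spans, current = [], []
    for token_id, label in zip(input_ids, labels):
        if label != -100:
            current.append(token_id)  # inside an assistant span
        elif current:
            spans.append(tokenizer.decode(current))  # span just ended
            current = []
    if current:  # flush a span that runs to the end of the sequence
        spans.append(tokenizer.decode(current))
    return spans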

View File

@ -0,0 +1,581 @@
{
"input_ids": [
100000,
32,
12465,
1439,
245,
13076,
3807,
285,
274,
18050,
15141,
20308,
13,
429,
20308,
4380,
9394,
11,
9333,
11,
285,
30513,
9789,
276,
254,
3807,
6,
82,
4313,
13,
185,
185,
185,
185,
5726,
25,
2461,
418,
254,
1853,
6663,
8247,
30,
185,
185,
77398,
25,
429,
1853,
6663,
8247,
418,
3074,
11,
5501,
11,
285,
10421,
13,
100001,
5726,
25,
207,
17882,
6213,
19462,
50209,
4625,
17901,
25628,
398,
185,
185,
77398,
25,
207,
6213,
19462,
50209,
4625,
504,
23807,
5871,
53947,
33939,
5804,
19304,
15047,
24291,
25628,
18752,
48548,
537,
34854,
15477,
885,
16938,
398,
207,
6213,
19462,
19304,
8549,
2217,
4118,
4800,
19304,
54533,
6213,
5118,
1306,
42753,
33939,
398,
15047,
18867,
15055,
16476,
15982,
537,
15615,
885,
5782,
19304,
1876,
17135,
60574,
98238,
885,
8340,
398,
6213,
19462,
337,
34854,
15477,
72652,
10263,
7992,
15477,
913,
42753,
19304,
6117,
13742,
6857,
6213,
9384,
3421,
31448,
398,
207,
1306,
59828,
31759,
13211,
16690,
8452,
32512,
1306,
42753,
33939,
5804,
19304,
15047,
18867,
15055,
78963,
6037,
8510,
8083,
4076,
19304,
1415,
16196,
537,
12385,
53913,
885,
18775,
58875,
1537,
398,
59828,
18867,
4012,
93589,
337,
34854,
15477,
19304,
9482,
1876,
33688,
1331,
5865,
7438,
885,
88642,
17135,
398,
7406,
59828,
4012,
8580,
16690,
47322,
45098,
19304,
15047,
18867,
27061,
1059,
6273,
47500,
537,
748,
45003,
56873,
47107,
45003,
81005,
34136,
4625,
19304,
851,
8039,
32672,
17135,
13951,
885,
59649,
398,
207,
39736,
19304,
6213,
19462,
50209,
4625,
17901,
5649,
25628,
18752,
63889,
48548,
537,
34854,
15477,
885,
16938,
398,
6213,
19462,
15055,
6213,
5118,
19304,
1306,
59828,
15055,
8580,
16690,
47322,
45098,
398,
4076,
40542,
34854,
15477,
18867,
1769,
6213,
19462,
31729,
19304,
851,
12383,
32672,
16938,
885,
59649,
398,
100001
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
429,
1853,
6663,
8247,
418,
3074,
11,
5501,
11,
285,
10421,
13,
100001,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
207,
6213,
19462,
50209,
4625,
504,
23807,
5871,
53947,
33939,
5804,
19304,
15047,
24291,
25628,
18752,
48548,
537,
34854,
15477,
885,
16938,
398,
207,
6213,
19462,
19304,
8549,
2217,
4118,
4800,
19304,
54533,
6213,
5118,
1306,
42753,
33939,
398,
15047,
18867,
15055,
16476,
15982,
537,
15615,
885,
5782,
19304,
1876,
17135,
60574,
98238,
885,
8340,
398,
6213,
19462,
337,
34854,
15477,
72652,
10263,
7992,
15477,
913,
42753,
19304,
6117,
13742,
6857,
6213,
9384,
3421,
31448,
398,
207,
1306,
59828,
31759,
13211,
16690,
8452,
32512,
1306,
42753,
33939,
5804,
19304,
15047,
18867,
15055,
78963,
6037,
8510,
8083,
4076,
19304,
1415,
16196,
537,
12385,
53913,
885,
18775,
58875,
1537,
398,
59828,
18867,
4012,
93589,
337,
34854,
15477,
19304,
9482,
1876,
33688,
1331,
5865,
7438,
885,
88642,
17135,
398,
7406,
59828,
4012,
8580,
16690,
47322,
45098,
19304,
15047,
18867,
27061,
1059,
6273,
47500,
537,
748,
45003,
56873,
47107,
45003,
81005,
34136,
4625,
19304,
851,
8039,
32672,
17135,
13951,
885,
59649,
398,
207,
39736,
19304,
6213,
19462,
50209,
4625,
17901,
5649,
25628,
18752,
63889,
48548,
537,
34854,
15477,
885,
16938,
398,
6213,
19462,
15055,
6213,
5118,
19304,
1306,
59828,
15055,
8580,
16690,
47322,
45098,
398,
4076,
40542,
34854,
15477,
18867,
1769,
6213,
19462,
31729,
19304,
851,
12383,
32672,
16938,
885,
59649,
398,
100001
],
"inputs_decode": "<begin▁of▁sentence>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n\n\nUser: What are the three primary colors?\n\nAssistant: The three primary colors are red, blue, and yellow.<end▁of▁sentence>User: 解释个人电脑和服务器之间的区别。\n\nAssistant: 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<end▁of▁sentence>",
"labels_decode": [
" The three primary colors are red, blue, and yellow.<end▁of▁sentence>",
" 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。<end▁of▁sentence>"
],
"seq_length": 284,
"seq_category": "None"
}
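
In this fixture both supervised spans terminate in <end▁of▁sentence> (id 100001, readable directly from the "labels" array above), so the model is trained to emit its own stop marker. A small check along those lines, with the helper name purely illustrative:

END_OF_ASSISTANT_ID = 100001  # <end▁of▁sentence> in this fixture

def spans_end_with_marker(labels, eoa_id=END_OF_ASSISTANT_ID) -> bool:
    # The last token of each contiguous non -100 run must be the marker.
    for i, label in enumerate(labels):
        span_ends_here = label != -100 and (i + 1 == len(labels) or labels[i + 1] == -100)
        if span_ends_here and label != eoa_id:
            return False
    return True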

File diff suppressed because it is too large

View File

@ -0,0 +1,919 @@
{
"input_ids": [
1,
733,
16289,
28793,
28705,
1824,
460,
272,
1712,
6258,
9304,
28804,
28705,
733,
28748,
16289,
28793,
415,
1712,
6258,
9304,
460,
2760,
28725,
5045,
28725,
304,
9684,
28723,
2,
733,
16289,
28793,
259,
29386,
30334,
28995,
29086,
29742,
31753,
29131,
29430,
29224,
29180,
29332,
29117,
28914,
29209,
29705,
28944,
28705,
733,
28748,
16289,
28793,
28705,
28995,
29086,
29742,
31753,
29131,
29430,
29224,
29180,
28971,
29745,
29824,
28988,
29338,
29066,
29173,
28914,
29382,
29481,
29363,
29401,
29531,
28924,
29928,
29550,
28914,
29345,
29059,
29209,
29705,
29010,
29160,
28963,
31787,
29041,
31483,
29011,
29253,
29021,
29131,
29261,
29084,
28944,
28705,
28995,
29086,
29742,
31753,
28924,
236,
164,
193,
29046,
30450,
29476,
28924,
28971,
29003,
28995,
29086,
29154,
28963,
29746,
29081,
29382,
28914,
29382,
29481,
29363,
28944,
29928,
29550,
29217,
29408,
28963,
29160,
29142,
29408,
28914,
29487,
29089,
29041,
232,
171,
180,
30488,
29131,
29500,
30978,
28924,
29052,
29074,
29798,
29037,
30344,
29824,
30344,
29675,
28914,
29298,
28963,
29265,
29313,
29131,
30239,
30727,
28944,
28995,
29086,
29742,
31753,
28914,
31483,
29011,
29253,
29021,
28969,
30812,
28971,
29518,
30131,
29144,
30168,
29253,
29021,
29263,
29081,
29382,
28914,
28924,
28988,
29202,
29537,
29052,
29074,
29521,
29020,
28995,
29086,
29259,
29243,
29258,
29037,
29018,
29251,
28944,
28705,
29746,
29430,
29224,
29180,
28971,
29003,
29105,
30493,
30303,
29050,
29195,
28963,
29106,
28914,
29259,
29243,
29746,
29081,
29382,
28914,
29382,
29481,
29363,
29401,
29531,
28924,
29928,
29550,
29217,
29408,
28963,
29160,
29003,
28963,
29106,
29279,
29954,
30344,
29824,
29533,
30229,
29430,
29224,
28924,
29118,
29533,
29992,
29041,
29742,
29169,
30182,
29011,
29131,
29019,
29011,
29359,
29226,
29414,
28944,
29430,
29224,
29180,
29217,
29408,
29259,
29059,
29366,
29261,
29084,
28914,
31483,
29011,
29253,
29021,
28924,
29457,
29958,
29052,
29074,
30761,
30154,
29366,
30685,
29435,
29131,
29495,
29007,
29117,
28914,
29798,
29037,
28944,
29590,
29160,
29430,
29224,
29180,
29259,
29059,
29428,
29569,
29050,
29195,
28963,
29106,
28914,
30164,
29562,
28924,
29928,
29550,
29217,
29408,
29253,
29613,
29292,
30080,
29326,
29093,
29180,
29041,
29050,
29329,
29195,
29188,
29079,
29131,
29050,
29329,
29195,
31483,
30358,
236,
172,
180,
29129,
29180,
28924,
29074,
29279,
29366,
29401,
29531,
28914,
29798,
29037,
30029,
29184,
29131,
234,
171,
182,
29018,
29261,
28944,
28705,
29758,
29332,
28924,
28995,
29086,
29742,
31753,
29131,
29430,
29224,
29180,
29332,
29117,
28914,
29345,
29059,
29209,
29705,
29010,
29160,
29928,
29550,
28914,
28963,
31787,
29041,
31483,
29011,
29253,
29021,
29131,
29261,
29084,
28944,
28995,
29086,
29742,
31753,
28963,
29160,
28995,
29086,
29154,
28963,
28924,
29746,
29430,
29224,
29180,
28963,
29160,
29428,
29569,
29050,
29195,
28963,
29106,
28914,
30164,
29562,
28944,
29430,
29224,
29180,
28914,
31483,
29011,
29253,
29021,
29217,
29408,
29605,
28995,
29086,
29742,
31753,
29250,
29366,
28924,
29074,
29321,
29425,
29401,
29531,
28914,
29261,
29084,
29131,
234,
171,
182,
29018,
29261,
28944,
2
],
"labels": [
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
415,
1712,
6258,
9304,
460,
2760,
28725,
5045,
28725,
304,
9684,
28723,
2,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
28705,
28995,
29086,
29742,
31753,
29131,
29430,
29224,
29180,
28971,
29745,
29824,
28988,
29338,
29066,
29173,
28914,
29382,
29481,
29363,
29401,
29531,
28924,
29928,
29550,
28914,
29345,
29059,
29209,
29705,
29010,
29160,
28963,
31787,
29041,
31483,
29011,
29253,
29021,
29131,
29261,
29084,
28944,
28705,
28995,
29086,
29742,
31753,
28924,
236,
164,
193,
29046,
30450,
29476,
28924,
28971,
29003,
28995,
29086,
29154,
28963,
29746,
29081,
29382,
28914,
29382,
29481,
29363,
28944,
29928,
29550,
29217,
29408,
28963,
29160,
29142,
29408,
28914,
29487,
29089,
29041,
232,
171,
180,
30488,
29131,
29500,
30978,
28924,
29052,
29074,
29798,
29037,
30344,
29824,
30344,
29675,
28914,
29298,
28963,
29265,
29313,
29131,
30239,
30727,
28944,
28995,
29086,
29742,
31753,
28914,
31483,
29011,
29253,
29021,
28969,
30812,
28971,
29518,
30131,
29144,
30168,
29253,
29021,
29263,
29081,
29382,
28914,
28924,
28988,
29202,
29537,
29052,
29074,
29521,
29020,
28995,
29086,
29259,
29243,
29258,
29037,
29018,
29251,
28944,
28705,
29746,
29430,
29224,
29180,
28971,
29003,
29105,
30493,
30303,
29050,
29195,
28963,
29106,
28914,
29259,
29243,
29746,
29081,
29382,
28914,
29382,
29481,
29363,
29401,
29531,
28924,
29928,
29550,
29217,
29408,
28963,
29160,
29003,
28963,
29106,
29279,
29954,
30344,
29824,
29533,
30229,
29430,
29224,
28924,
29118,
29533,
29992,
29041,
29742,
29169,
30182,
29011,
29131,
29019,
29011,
29359,
29226,
29414,
28944,
29430,
29224,
29180,
29217,
29408,
29259,
29059,
29366,
29261,
29084,
28914,
31483,
29011,
29253,
29021,
28924,
29457,
29958,
29052,
29074,
30761,
30154,
29366,
30685,
29435,
29131,
29495,
29007,
29117,
28914,
29798,
29037,
28944,
29590,
29160,
29430,
29224,
29180,
29259,
29059,
29428,
29569,
29050,
29195,
28963,
29106,
28914,
30164,
29562,
28924,
29928,
29550,
29217,
29408,
29253,
29613,
29292,
30080,
29326,
29093,
29180,
29041,
29050,
29329,
29195,
29188,
29079,
29131,
29050,
29329,
29195,
31483,
30358,
236,
172,
180,
29129,
29180,
28924,
29074,
29279,
29366,
29401,
29531,
28914,
29798,
29037,
30029,
29184,
29131,
234,
171,
182,
29018,
29261,
28944,
28705,
29758,
29332,
28924,
28995,
29086,
29742,
31753,
29131,
29430,
29224,
29180,
29332,
29117,
28914,
29345,
29059,
29209,
29705,
29010,
29160,
29928,
29550,
28914,
28963,
31787,
29041,
31483,
29011,
29253,
29021,
29131,
29261,
29084,
28944,
28995,
29086,
29742,
31753,
28963,
29160,
28995,
29086,
29154,
28963,
28924,
29746,
29430,
29224,
29180,
28963,
29160,
29428,
29569,
29050,
29195,
28963,
29106,
28914,
30164,
29562,
28944,
29430,
29224,
29180,
28914,
31483,
29011,
29253,
29021,
29217,
29408,
29605,
28995,
29086,
29742,
31753,
29250,
29366,
28924,
29074,
29321,
29425,
29401,
29531,
28914,
29261,
29084,
29131,
234,
171,
182,
29018,
29261,
28944,
2
],
"inputs_decode": "<s> [INST] What are the three primary colors? [/INST] The three primary colors are red, blue, and yellow.</s> [INST] 解释个人电脑和服务器之间的区别。 [/INST] 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。</s>",
"labels_decode": [
" The three primary colors are red, blue, and yellow.</s>",
" 个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。</s>"
],
"seq_length": 453,
"seq_category": "None"
}
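
Taken together, these files form a per-template regression suite; a possible harness (directory layout and paths are assumptions) would simply sweep them with the checks sketched earlier:

import glob
import pytest

# Hypothetical location of the fixtures shown in this diff.
FIXTURES = glob.glob("tests/test_data/chat_template/*.json")

@pytest.mark.parametrize("path", FIXTURES)
def test_chat_template_fixture(path):
    check_fixture(path)  # masking/alignment check from the first sketch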